Browse Source

Merge pull request #235 from captainst/master

Add object detection example and some other minor fixes
tags/v0.9
Haiping GitHub 6 years ago
parent
commit
5b6aceea7b
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 321 additions and 89 deletions
  1. +1
    -0
      .gitignore
  2. +1
    -0
      README.md
  3. +0
    -6
      TensorFlow.NET.sln
  4. +3
    -2
      docs/source/LogisticRegression.md
  5. +0
    -0
      graph/README.md
  6. BIN
      graph/cond_test.meta
  7. BIN
      graph/kmeans.meta
  8. +23
    -0
      src/TensorFlowNET.Core/Framework/op_def_registry.py.cs
  9. +1
    -1
      src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
  10. +1
    -19
      src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  11. +0
    -34
      src/TensorFlowNET.Core/Operations/Operation.Implicit.cs
  12. +21
    -0
      src/TensorFlowNET.Core/Operations/Operation.cs
  13. +0
    -5
      src/TensorFlowNET.Core/Python.cs
  14. +6
    -6
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  15. +3
    -3
      src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
  16. +1
    -1
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  17. +13
    -1
      src/TensorFlowNET.Core/Sessions/_FetchMapper.cs
  18. +0
    -4
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  19. +2
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  20. +1
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  21. +0
    -4
      src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
  22. BIN
      tensorflowlib/runtimes/win-x64/native/tensorflow.dll
  23. +169
    -0
      test/TensorFlowNET.Examples/ObjectDetection.cs
  24. +1
    -1
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  25. +74
    -0
      test/TensorFlowNET.Examples/Utility/PbtxtParser.cs
  26. +0
    -1
      test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj

+ 1
- 0
.gitignore View File

@@ -62,6 +62,7 @@ StyleCopReport.xml
*_p.c
*_i.h
*.ilk
*.meta
*.obj
*.iobj
*.pch


+ 1
- 0
README.md View File

@@ -1,4 +1,5 @@
# TensorFlow.NET
TensorFlow.NET provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework.
TensorFlow.NET (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework.

[![Join the chat at https://gitter.im/publiclab/publiclab](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/sci-sharp/community)


+ 0
- 6
TensorFlow.NET.sln View File

@@ -9,8 +9,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "t
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\TensorFlowNET.Core\TensorFlowNET.Core.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{DE97EAD7-B92C-4112-9690-91C40A97179E}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -29,10 +27,6 @@ Global
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU
{DE97EAD7-B92C-4112-9690-91C40A97179E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{DE97EAD7-B92C-4112-9690-91C40A97179E}.Debug|Any CPU.Build.0 = Debug|Any CPU
{DE97EAD7-B92C-4112-9690-91C40A97179E}.Release|Any CPU.ActiveCfg = Release|Any CPU
{DE97EAD7-B92C-4112-9690-91C40A97179E}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE


+ 3
- 2
docs/source/LogisticRegression.md View File

@@ -23,8 +23,9 @@ logistic回归通过函数S将ax+b对应到一个隐状态p,p = S(ax+b),然
将t换成ax+b,可以得到逻辑回归模型的参数形式:
P(x;a,b) = 1 / (1 + e^(-ax+b))

![image](https://github.com/SciEvan/TensorFlow.NET/blob/master/docs/source/sigmoid.png)
sigmoid函数的图像
![image](https://github.com/SciEvan/TensorFlow.NET/tree/master/docs/source/sigmoid.png)

sigmoid函数的图像

By the function of the function S, we can limit the output value to the interval [0, 1],
p(x) can then be used to represent the probability p(y=1|x), the probability that y is divided into 1 group when an x occurs.


+ 0
- 0
graph/README.md View File


BIN
graph/cond_test.meta View File


BIN
graph/kmeans.meta View File


+ 23
- 0
src/TensorFlowNET.Core/Framework/op_def_registry.py.cs View File

@@ -20,9 +20,32 @@ namespace Tensorflow

foreach (var op_def in op_list.Op)
_registered_ops[op_def.Name] = op_def;

if (!_registered_ops.ContainsKey("NearestNeighbors"))
_registered_ops["NearestNeighbors"] = op_NearestNeighbors();
}

return _registered_ops;
}

/// <summary>
/// Doesn't work because the op can't be found on binary
/// </summary>
/// <returns></returns>
private static OpDef op_NearestNeighbors()
{
var def = new OpDef
{
Name = "NearestNeighbors"
};

def.InputArg.Add(new ArgDef { Name = "points", Type = DataType.DtFloat });
def.InputArg.Add(new ArgDef { Name = "centers", Type = DataType.DtFloat });
def.InputArg.Add(new ArgDef { Name = "k", Type = DataType.DtInt64 });
def.OutputArg.Add(new ArgDef { Name = "nearest_center_indices", Type = DataType.DtInt64 });
def.OutputArg.Add(new ArgDef { Name = "nearest_center_distances", Type = DataType.DtFloat });

return def;
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs View File

@@ -335,7 +335,7 @@ namespace Tensorflow.Operations
ret.Enter();
foreach (var nested_def in proto.NestedContexts)
from_control_flow_context_def(nested_def, import_scope: import_scope);
throw new NotImplementedException("");
ret.Exit();
return ret;
}


+ 1
- 19
src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs View File

@@ -3,8 +3,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Operations.ControlFlows;
using static Tensorflow.ControlFlowContextDef;

namespace Tensorflow.Operations
{
/// <summary>
@@ -185,23 +184,6 @@ namespace Tensorflow.Operations
return null;
}

/// <summary>
/// Deserializes `context_def` into the appropriate ControlFlowContext.
/// </summary>
/// <param name="context_def">ControlFlowContextDef proto</param>
/// <param name="import_scope">Name scope to add</param>
/// <returns>A ControlFlowContext subclass</returns>
protected ControlFlowContext from_control_flow_context_def(ControlFlowContextDef context_def, string import_scope = "")
{
switch (context_def.CtxtCase)
{
case CtxtOneofCase.CondCtxt:
return new CondContext().from_proto(context_def.CondCtxt, import_scope: import_scope);
}
throw new NotImplementedException($"Unknown ControlFlowContextDef field: {context_def.CtxtCase}");
}

public object to_proto()
{
throw new NotImplementedException();


+ 0
- 34
src/TensorFlowNET.Core/Operations/Operation.Implicit.cs View File

@@ -1,34 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// Convert to other datatype implicitly
/// </summary>
public partial class Operation
{
public static implicit operator Operation(IntPtr handle) => new Operation(handle);
public static implicit operator IntPtr(Operation op) => op._handle;
public static implicit operator Tensor(Operation op) => op.output;

public override string ToString()
{
return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{name}' type={OpType}";
}

public override bool Equals(object obj)
{
switch (obj)
{
case IntPtr val:
return val == _handle;
case Operation val:
return val._handle == _handle;
}

return base.Equals(obj);
}
}
}

+ 21
- 0
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -248,6 +248,27 @@ namespace Tensorflow
s.Check();
return NodeDef.Parser.ParseFrom(buffer);
}
}

public override string ToString()
{
return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{name}' type={OpType}";
}

public static implicit operator Operation(IntPtr handle) => new Operation(handle);
public static implicit operator IntPtr(Operation op) => op._handle;

public override bool Equals(object obj)
{
switch (obj)
{
case IntPtr val:
return val == _handle;
case Operation val:
return val._handle == _handle;
}

return base.Equals(obj);
}
/// <summary>


+ 0
- 5
src/TensorFlowNET.Core/Python.cs View File

@@ -27,11 +27,6 @@ namespace Tensorflow
return Enumerable.Range(0, end);
}

protected IEnumerable<int> range(int start, int end)
{
return Enumerable.Range(start, end);
}

public static T New<T>(object args) where T : IPyClass
{
var instance = Activator.CreateInstance<T>();


+ 6
- 6
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -204,12 +204,6 @@ namespace Tensorflow

switch (tensor.dtype)
{
case TF_DataType.TF_BOOL:
var bools = new bool[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i));
nd = np.array(bools).reshape(ndims);
break;
case TF_DataType.TF_STRING:
var bytes = tensor.Data();
// wired, don't know why we have to start from offset 9.
@@ -217,6 +211,12 @@ namespace Tensorflow
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
nd = np.array(str).reshape();
break;
case TF_DataType.TF_UINT8:
var _bytes = new byte[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
_bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i));
nd = np.array(_bytes).reshape(ndims);
break;
case TF_DataType.TF_INT16:
var shorts = new short[tensor.size];
for (ulong i = 0; i < tensor.size; i++)


+ 3
- 3
src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs View File

@@ -10,9 +10,9 @@ namespace Tensorflow
/// </summary>
public class _ElementFetchMapper : _FetchMapper
{
private Func<List<NDArray>, object> _contraction_fn;
private Func<List<object>, object> _contraction_fn;

public _ElementFetchMapper(object[] fetches, Func<List<NDArray>, object> contraction_fn)
public _ElementFetchMapper(object[] fetches, Func<List<object>, object> contraction_fn)
{
var g = ops.get_default_graph();
ITensorOrOperation el = null;
@@ -31,7 +31,7 @@ namespace Tensorflow
/// </summary>
/// <param name="values"></param>
/// <returns></returns>
public override NDArray build_results(List<NDArray> values)
public override NDArray build_results(List<object> values)
{
NDArray result = null;



+ 1
- 1
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow

public NDArray build_results(BaseSession session, NDArray[] tensor_values)
{
var full_values = new List<NDArray>();
var full_values = new List<object>();
if (_final_fetches.Count != tensor_values.Length)
throw new InvalidOperationException("_final_fetches mismatch tensor_values");



+ 13
- 1
src/TensorFlowNET.Core/Sessions/_FetchMapper.cs View File

@@ -24,7 +24,19 @@ namespace Tensorflow
{
var type = values[0].GetType();
var nd = new NDArray(type, values.Count);
nd.SetData(values.ToArray());

switch (type.Name)
{
case "Single":
nd.SetData(values.Select(x => (float)x).ToArray());
break;
case "NDArray":
NDArray[] arr = values.Select(x => (NDArray)x).ToArray();
nd = new NDArray(arr);
break;
default:
break;
}
return nd;
}



+ 0
- 4
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -57,8 +57,4 @@ More math/ linalg APIs.</PackageReleaseNotes>
<Folder Include="Keras\Initializers\" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
</ItemGroup>

</Project>

+ 2
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -67,7 +67,8 @@ namespace Tensorflow
case "Double":
Marshal.Copy(nd1.Data<double>(), 0, dotHandle, nd.size);
break;
//case "Byte":
case "Byte":
Marshal.Copy(nd1.Data<byte>(), 0, dotHandle, nd.size);
/*var bb = nd.Data<byte>();
var bytes = Marshal.AllocHGlobal(bb.Length);
Marshal.Copy(bb, 0, bytes, bb.Length);


+ 1
- 0
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -195,6 +195,7 @@ namespace Tensorflow
case "Double":
return TF_DataType.TF_DOUBLE;
case "Byte":
return TF_DataType.TF_UINT8;
case "String":
return TF_DataType.TF_STRING;
default:


+ 0
- 4
src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs View File

@@ -2,7 +2,6 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Operations;

namespace Tensorflow
{
@@ -116,9 +115,6 @@ namespace Tensorflow
case List<ITensorOrOperation> values:
foreach (var element in values) ;
break;
case List<CondContext> values:
foreach (var element in values) ;
break;
default:
throw new NotImplementedException("_build_internal.check_collection_list");
}


BIN
tensorflowlib/runtimes/win-x64/native/tensorflow.dll View File


+ 169
- 0
test/TensorFlowNET.Examples/ObjectDetection.cs View File

@@ -0,0 +1,169 @@
using Newtonsoft.Json;
using NumSharp;
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using Tensorflow;
using TensorFlowNET.Examples.Utility;
using System.Drawing;
using System.Drawing.Drawing2D;
using System.Linq;

namespace TensorFlowNET.Examples
{

public class ObjectDetection : Python, IExample
{
public int Priority => 7;
public bool Enabled { get; set; } = true;
public string Name => "Image Recognition";
public float MIN_SCORE = 0.5f;

string modelDir = "ssd_mobilenet_v1_coco_2018_01_28";
string imageDir = "images";
string pbFile = "frozen_inference_graph.pb";
string labelFile = "mscoco_label_map.pbtxt";
string picFile = "input.jpg";

public bool Run()
{
//buildOutputImage(null);

// read in the input image
var imgArr = ReadTensorFromImageFile(Path.Join(imageDir, "input.jpg"));

var graph = new Graph().as_default();
graph.Import(Path.Join(modelDir, pbFile));

var tensorNum = graph.OperationByName("num_detections").outputs[0];
var tensorBoxes = graph.OperationByName("detection_boxes").outputs[0];
var tensorScores = graph.OperationByName("detection_scores").outputs[0];
var tensorClasses = graph.OperationByName("detection_classes").outputs[0];

var imgTensor = graph.OperationByName("image_tensor").outputs[0];


Tensor[] outTensorArr = new Tensor[] { tensorNum, tensorBoxes, tensorScores, tensorClasses };

with(tf.Session(graph), sess =>
{
var results = sess.run(outTensorArr, new FeedItem(imgTensor, imgArr));
NDArray[] resultArr = results.Data<NDArray>();
buildOutputImage(resultArr);
});

return true;
}

public void PrepareData()
{
if (!Directory.Exists(modelDir))
Directory.CreateDirectory(modelDir);

if (!File.Exists(Path.Join(modelDir, "ssd_mobilenet_v1_coco.tar.gz")))
{
// get model file
string url = "http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz";

Utility.Web.Download(url, modelDir, "ssd_mobilenet_v1_coco.tar.gz");
}

if (!File.Exists(Path.Join(modelDir, "frozen_inference_graph.pb")))
{
Utility.Compress.ExtractTGZ(Path.Join(modelDir, "ssd_mobilenet_v1_coco.tar.gz"), "./");
}


// download sample picture
if (!Directory.Exists(imageDir))
Directory.CreateDirectory(imageDir);

if (!File.Exists(Path.Join(imageDir, "input.jpg")))
{
string url = $"https://github.com/tensorflow/models/raw/master/research/object_detection/test_images/image2.jpg";
Utility.Web.Download(url, imageDir, "input.jpg");
}

// download the pbtxt file
if (!File.Exists(Path.Join(modelDir, "mscoco_label_map.pbtxt")))
{
string url = $"https://github.com/tensorflow/models/blob/master/research/object_detection/data/mscoco_label_map.pbtxt";
Utility.Web.Download(url, modelDir, "mscoco_label_map.pbtxt");
}
}

private NDArray ReadTensorFromImageFile(string file_name)
{
return with(tf.Graph().as_default(), graph =>
{
var file_reader = tf.read_file(file_name, "file_reader");
var decodeJpeg = tf.image.decode_jpeg(file_reader, channels: 3, name: "DecodeJpeg");
var casted = tf.cast(decodeJpeg, TF_DataType.TF_UINT8);
var dims_expander = tf.expand_dims(casted, 0);
return with(tf.Session(graph), sess => sess.run(dims_expander));
});
}

private void buildOutputImage(NDArray[] resultArr)
{
// get pbtxt items
PbtxtItems pbTxtItems = PbtxtParser.ParsePbtxtFile(Path.Join(modelDir, "mscoco_label_map.pbtxt"));

// get bitmap
Bitmap bitmap = new Bitmap(Path.Join(imageDir, "input.jpg"));

float[] scores = resultArr[2].Data<float>();

for (int i=0; i<scores.Length; i++)
{
float score = scores[i];
if (score > MIN_SCORE)
{
//var boxes = resultArr[1].Data<float[,,]>();
float[] boxes = resultArr[1].Data<float>();
float top = boxes[i * 4] * bitmap.Height;
float left = boxes[i * 4 + 1] * bitmap.Width;
float bottom = boxes[i * 4 + 2] * bitmap.Height;
float right = boxes[i * 4 + 3] * bitmap.Width;

Rectangle rect = new Rectangle()
{
X = (int)left,
Y = (int)top,
Width = (int)(right - left),
Height = (int)(bottom - top)
};

float[] ids = resultArr[3].Data<float>();

string name = pbTxtItems.items.Where(w => w.id == (int)ids[i]).Select(s=>s.display_name).FirstOrDefault();

drawObjectOnBitmap(bitmap, rect, score, name);
}
}

bitmap.Save(Path.Join(imageDir, "output.jpg"));
}

private void drawObjectOnBitmap(Bitmap bmp, Rectangle rect, float score, string name)
{
using (Graphics graphic = Graphics.FromImage(bmp))
{
graphic.SmoothingMode = SmoothingMode.AntiAlias;
using (Pen pen = new Pen(Color.Red, 2))
{
graphic.DrawRectangle(pen, rect);

Point p = new Point(rect.Right + 5, rect.Top + 5);
string text = string.Format("{0}:{1}%", name, (int)(score * 100));
graphic.DrawString(text, new Font("Verdana", 8), Brushes.Red, p);
}
}
}
}
}

+ 1
- 1
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -10,10 +10,10 @@
<PackageReference Include="Colorful.Console" Version="1.2.9" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.1" />
<PackageReference Include="SharpZipLib" Version="1.1.0" />
<PackageReference Include="System.Drawing.Common" Version="4.5.1" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
</ItemGroup>



+ 74
- 0
test/TensorFlowNET.Examples/Utility/PbtxtParser.cs View File

@@ -0,0 +1,74 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace TensorFlowNET.Examples.Utility
{
public class PbtxtItem
{
public string name { get; set; }
public int id { get; set; }
public string display_name { get; set; }
}
public class PbtxtItems
{
public List<PbtxtItem> items { get; set; }
}

public class PbtxtParser
{
public static PbtxtItems ParsePbtxtFile(string filePath)
{
string line;
string newText = "{\"items\":[";

try
{
using (System.IO.StreamReader reader = new System.IO.StreamReader(filePath))
{

while ((line = reader.ReadLine()) != null)
{
string newline = string.Empty;

if (line.Contains("{"))
{
newline = line.Replace("item", "").Trim();
//newText += line.Insert(line.IndexOf("=") + 1, "\"") + "\",";
newText += newline;
}
else if (line.Contains("}"))
{
newText = newText.Remove(newText.Length - 1);
newText += line;
newText += ",";
}
else
{
newline = line.Replace(":", "\":").Trim();
newline = "\"" + newline;// newline.Insert(0, "\"");
newline += ",";

newText += newline;
}

}

newText = newText.Remove(newText.Length - 1);
newText += "]}";

reader.Close();
}

PbtxtItems items = JsonConvert.DeserializeObject<PbtxtItems>(newText);

return items;
}
catch (Exception ex)
{
return null;
}
}
}
}

+ 0
- 1
test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj View File

@@ -22,7 +22,6 @@
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
<ProjectReference Include="..\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj" />
</ItemGroup>


Loading…
Cancel
Save