Browse Source

Add Nearest Neighbor

release v0.5.1
tags/v0.9
Oceania2018 6 years ago
parent
commit
8654b41c2a
19 changed files with 184 additions and 40 deletions
  1. +1
    -0
      README.md
  2. +3
    -0
      docs/source/NearestNeighbor.md
  3. +1
    -0
      docs/source/index.rst
  4. +12
    -0
      src/TensorFlowNET.Core/APIs/tf.math.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  6. +31
    -15
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  7. +28
    -1
      src/TensorFlowNET.Core/Operations/math_ops.cs
  8. +6
    -0
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  9. +4
    -1
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  10. +8
    -6
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  11. +6
    -3
      src/TensorFlowNET.Core/Variables/variables.py.cs
  12. +1
    -1
      src/TensorFlowNET.Core/ops.py.cs
  13. +1
    -1
      test/TensorFlowNET.Examples/ImageRecognition.cs
  14. +1
    -1
      test/TensorFlowNET.Examples/LogisticRegression.cs
  15. +70
    -0
      test/TensorFlowNET.Examples/NearestNeighbor.cs
  16. +6
    -6
      test/TensorFlowNET.Examples/Program.cs
  17. +1
    -1
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  18. +1
    -1
      test/TensorFlowNET.Examples/TextClassification/TextClassificationWithMovieReviews.cs
  19. +2
    -2
      test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj

+ 1
- 0
README.md View File

@@ -73,6 +73,7 @@ Read the docs & book [The Definitive Guide to Tensorflow.NET](https://tensorflow
* [Image Recognition](test/TensorFlowNET.Examples/ImageRecognition.cs) * [Image Recognition](test/TensorFlowNET.Examples/ImageRecognition.cs)
* [Linear Regression](test/TensorFlowNET.Examples/LinearRegression.cs) * [Linear Regression](test/TensorFlowNET.Examples/LinearRegression.cs)
* [Logistic Regression](test/TensorFlowNET.Examples/LogisticRegression.cs) * [Logistic Regression](test/TensorFlowNET.Examples/LogisticRegression.cs)
* [Nearest Neighbor](test/TensorFlowNET.Examples/NearestNeighbor.cs)
* [Text Classification](test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs) * [Text Classification](test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs)
* [CNN Text Classification](test/TensorFlowNET.Examples/CnnTextClassification.cs) * [CNN Text Classification](test/TensorFlowNET.Examples/CnnTextClassification.cs)
* [Naive Bayes Classification](test/TensorFlowNET.Examples/NaiveBayesClassifier.cs) * [Naive Bayes Classification](test/TensorFlowNET.Examples/NaiveBayesClassifier.cs)


+ 3
- 0
docs/source/NearestNeighbor.md View File

@@ -0,0 +1,3 @@
# Chapter. Nearest Neighbor

The nearest neighbour algorithm was one of the first algorithms used to solve the travelling salesman problem. In it, the salesman starts at a random city and repeatedly visits the nearest city until all have been visited. It quickly yields a short tour, but usually not the optimal one.

+ 1
- 0
docs/source/index.rst View File

@@ -27,4 +27,5 @@ Welcome to TensorFlow.NET's documentation!
EagerMode EagerMode
LinearRegression LinearRegression
LogisticRegression LogisticRegression
NearestNeighbor
ImageRecognition ImageRecognition

+ 12
- 0
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -6,9 +6,18 @@ namespace Tensorflow
{ {
public static partial class tf public static partial class tf
{ {
public static Tensor abs(Tensor x, string name = null)
=> math_ops.abs(x, name);

public static Tensor add(Tensor a, Tensor b) public static Tensor add(Tensor a, Tensor b)
=> gen_math_ops.add(a, b); => gen_math_ops.add(a, b);


public static Tensor arg_max(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null)
=> gen_math_ops.arg_max(input, dimension, output_type: output_type, name: name);

public static Tensor arg_min(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null)
=> gen_math_ops.arg_min(input, dimension, output_type: output_type, name: name);

public static Tensor sub(Tensor a, Tensor b) public static Tensor sub(Tensor a, Tensor b)
=> gen_math_ops.sub(a, b); => gen_math_ops.sub(a, b);


@@ -27,6 +36,9 @@ namespace Tensorflow
public static Tensor multiply(Tensor x, Tensor y) public static Tensor multiply(Tensor x, Tensor y)
=> gen_math_ops.mul(x, y); => gen_math_ops.mul(x, y);


public static Tensor negative(Tensor x, string name = null)
=> gen_math_ops.neg(x, name);

public static Tensor divide<T>(Tensor x, T[] y, string name = null) where T : struct public static Tensor divide<T>(Tensor x, T[] y, string name = null) where T : struct
=> x / ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"); => x / ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y");




+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -355,7 +355,7 @@ namespace Tensorflow
return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray(); return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray();
} }


public object get_collection(string name, string scope = "")
public object get_collection(string name, string scope = null)
{ {
return _collections.ContainsKey(name) ? _collections[name] : null; return _collections.ContainsKey(name) ? _collections[name] : null;
} }


+ 31
- 15
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -9,6 +9,30 @@ namespace Tensorflow
public static class gen_math_ops public static class gen_math_ops
{ {
public static OpDefLibrary _op_def_lib = new OpDefLibrary(); public static OpDefLibrary _op_def_lib = new OpDefLibrary();
/// <summary>
/// Returns the index with the largest value across dimensions of a tensor.
/// </summary>
/// <param name="input"></param>
/// <param name="dimension"></param>
/// <param name="output_type"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor arg_max(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null)
=> _op_def_lib._apply_op_helper("ArgMax", name, args: new { input, dimension, output_type }).outputs[0];
/// <summary>
/// Returns the index with the smallest value across dimensions of a tensor.
/// </summary>
/// <param name="input"></param>
/// <param name="dimension"></param>
/// <param name="output_type"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor arg_min(Tensor input, int dimension, TF_DataType output_type= TF_DataType.TF_INT64, string name= null)
=>_op_def_lib._apply_op_helper("ArgMin", name, args: new { input, dimension, output_type }).outputs[0];
/// <summary> /// <summary>
/// Computes the mean of elements across dimensions of a tensor. /// Computes the mean of elements across dimensions of a tensor.
/// Reduces `input` along the dimensions given in `axis`. Unless /// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in /// `axis`. If `keep_dims` is true, the reduced dimensions are retained with length 1. /// Reduces `input` along the dimensions given in `axis`. Unless /// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in /// `axis`. If `keep_dims` is true, the reduced dimensions are retained with length 1.
@@ -207,6 +231,13 @@ namespace Tensorflow
return _op.outputs[0]; return _op.outputs[0];
} }
public static Tensor _abs(Tensor x, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Abs", name, new { x });
return _op.outputs[0];
}
public static Tensor _max<Tx, Ty>(Tx input, Ty axis, bool keep_dims=false, string name = null) public static Tensor _max<Tx, Ty>(Tx input, Ty axis, bool keep_dims=false, string name = null)
{ {
var _op = _op_def_lib._apply_op_helper("Max", name, new { input, reduction_indices = axis, keep_dims }); var _op = _op_def_lib._apply_op_helper("Max", name, new { input, reduction_indices = axis, keep_dims });
@@ -249,20 +280,5 @@ namespace Tensorflow
return _op.outputs[0]; return _op.outputs[0];
} }
/// <summary>
/// Returns the index with the largest value across dimensions of a tensor.
/// </summary>
/// <param name="input"></param>
/// <param name="dimension"></param>
/// <param name="output_type"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor arg_max(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null)
{
var _op = _op_def_lib._apply_op_helper("ArgMax", name, new { input, dimension, output_type });
return _op.outputs[0];
}
} }
} }

src/TensorFlowNET.Core/Operations/math_ops.py.cs → src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -6,9 +6,36 @@ using Tensorflow.Framework;


namespace Tensorflow namespace Tensorflow
{ {
/// <summary>
/// python\ops\math_ops.py
/// </summary>
public class math_ops : Python public class math_ops : Python
{ {
public static Tensor add(Tensor x, Tensor y, string name = null) => gen_math_ops.add(x, y, name);
public static Tensor abs(Tensor x, string name = null)
{
return with(ops.name_scope(name, "Abs", new { x }), scope =>
{
x = ops.convert_to_tensor(x, name: "x");
if (x.dtype.is_complex())
throw new NotImplementedException("math_ops.abs for dtype.is_complex");
//return gen_math_ops.complex_abs(x, Tout: x.dtype.real_dtype, name: name);
return gen_math_ops._abs(x, name: name);
});
}

public static Tensor add(Tensor x, Tensor y, string name = null)
=> gen_math_ops.add(x, y, name);

public static Tensor add(Tensor x, string name = null)
{
return with(ops.name_scope(name, "Abs", new { x }), scope =>
{
name = scope;
x = ops.convert_to_tensor(x, name: "x");

return gen_math_ops._abs(x, name: name);
});
}


public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
{ {

+ 6
- 0
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -222,6 +222,12 @@ namespace Tensorflow
ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); ints[i] = *(int*)(offset + (int)(tensor.itemsize * i));
nd = np.array(ints).reshape(ndims); nd = np.array(ints).reshape(ndims);
break; break;
case TF_DataType.TF_INT64:
var longs = new long[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
longs[i] = *(long*)(offset + (int)(tensor.itemsize * i));
nd = np.array(longs).reshape(ndims);
break;
case TF_DataType.TF_FLOAT: case TF_DataType.TF_FLOAT:
var floats = new float[tensor.size]; var floats = new float[tensor.size];
for (ulong i = 0; i < tensor.size; i++) for (ulong i = 0; i < tensor.size; i++)


+ 4
- 1
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -65,6 +65,9 @@ namespace Tensorflow
case "Int32": case "Int32":
full_values.Add(value.Data<int>(0)); full_values.Add(value.Data<int>(0));
break; break;
case "Int64":
full_values.Add(value.Data<long>(0));
break;
case "Single": case "Single":
full_values.Add(value.Data<float>(0)); full_values.Add(value.Data<float>(0));
break; break;
@@ -78,7 +81,7 @@ namespace Tensorflow
} }
else else
{ {
full_values.Add(value[np.arange(1)]);
full_values.Add(value[np.arange(0, value.shape[0])]);
} }
} }
i += 1; i += 1;


+ 8
- 6
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -4,7 +4,7 @@
<TargetFramework>netstandard2.0</TargetFramework> <TargetFramework>netstandard2.0</TargetFramework>
<AssemblyName>TensorFlow.NET</AssemblyName> <AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace> <RootNamespace>Tensorflow</RootNamespace>
<Version>0.5.0</Version>
<Version>0.5.1</Version>
<Authors>Haiping Chen</Authors> <Authors>Haiping Chen</Authors>
<Company>SciSharp STACK</Company> <Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> <GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -16,11 +16,13 @@
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags>
<Description>Google's TensorFlow binding in .NET Standard. <Description>Google's TensorFlow binding in .NET Standard.
Docs: https://tensorflownet.readthedocs.io</Description> Docs: https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.5.0.0</AssemblyVersion>
<PackageReleaseNotes>Add Logistic Regression to do MNIST.
Add a lot of APIs to build neural networks model</PackageReleaseNotes>
<AssemblyVersion>0.5.1.0</AssemblyVersion>
<PackageReleaseNotes>Changes since v0.5:
Added Nearest Neighbor.
Add a lot of APIs to build neural networks model.
Bug fix.</PackageReleaseNotes>
<LangVersion>7.2</LangVersion> <LangVersion>7.2</LangVersion>
<FileVersion>0.5.0.0</FileVersion>
<FileVersion>0.5.1.0</FileVersion>
</PropertyGroup> </PropertyGroup>


<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
@@ -44,7 +46,7 @@ Add a lot of APIs to build neural networks model</PackageReleaseNotes>


<ItemGroup> <ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.7.0" /> <PackageReference Include="Google.Protobuf" Version="3.7.0" />
<PackageReference Include="NumSharp" Version="0.8.1" />
<PackageReference Include="NumSharp" Version="0.8.2" />
</ItemGroup> </ItemGroup>


<ItemGroup> <ItemGroup>


+ 6
- 3
src/TensorFlowNET.Core/Variables/variables.py.cs View File

@@ -47,11 +47,11 @@ namespace Tensorflow
/// special tokens filters by prefix. /// special tokens filters by prefix.
/// </param> /// </param>
/// <returns>A list of `Variable` objects.</returns> /// <returns>A list of `Variable` objects.</returns>
public static List<RefVariable> global_variables(string scope = "")
public static List<RefVariable> global_variables(string scope = null)
{ {
var result = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); var result = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope);


return result as List<RefVariable>;
return result == null ? new List<RefVariable>() : result as List<RefVariable>;
} }


/// <summary> /// <summary>
@@ -62,7 +62,10 @@ namespace Tensorflow
/// <returns>An Op that run the initializers of all the specified variables.</returns> /// <returns>An Op that run the initializers of all the specified variables.</returns>
public static Operation variables_initializer(RefVariable[] var_list, string name = "init") public static Operation variables_initializer(RefVariable[] var_list, string name = "init")
{ {
return control_flow_ops.group(var_list.Select(x => x.initializer).ToArray(), name);
if (var_list.Length > 0)
return control_flow_ops.group(var_list.Select(x => x.initializer).ToArray(), name);
else
return gen_control_flow_ops.no_op(name: name);
} }
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/ops.py.cs View File

@@ -41,7 +41,7 @@ namespace Tensorflow
/// list contains the values in the order under which they were /// list contains the values in the order under which they were
/// collected. /// collected.
/// </returns> /// </returns>
public static object get_collection(string key, string scope = "")
public static object get_collection(string key, string scope = null)
{ {
return get_default_graph().get_collection(key, scope); return get_default_graph().get_collection(key, scope);
} }


+ 1
- 1
test/TensorFlowNET.Examples/ImageRecognition.cs View File

@@ -12,7 +12,7 @@ namespace TensorFlowNET.Examples
{ {
public class ImageRecognition : Python, IExample public class ImageRecognition : Python, IExample
{ {
public int Priority => 5;
public int Priority => 6;
public bool Enabled => true; public bool Enabled => true;
public string Name => "Image Recognition"; public string Name => "Image Recognition";




+ 1
- 1
test/TensorFlowNET.Examples/LogisticRegression.cs View File

@@ -98,7 +98,7 @@ namespace TensorFlowNET.Examples


public void PrepareData() public void PrepareData()
{ {
mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true);
} }


public void SaveModel(Session sess) public void SaveModel(Session sess)


+ 70
- 0
test/TensorFlowNET.Examples/NearestNeighbor.cs View File

@@ -0,0 +1,70 @@
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;
using TensorFlowNET.Examples.Utility;

namespace TensorFlowNET.Examples
{
/// <summary>
/// A nearest neighbor learning algorithm example
/// This example is using the MNIST database of handwritten digits
/// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/nearest_neighbor.py
/// </summary>
public class NearestNeighbor : Python, IExample
{
public int Priority => 5;
public bool Enabled => true;
public string Name => "Nearest Neighbor";
Datasets mnist;
NDArray Xtr, Ytr, Xte, Yte;

public bool Run()
{
// tf Graph Input
var xtr = tf.placeholder(tf.float32, new TensorShape(-1, 784));
var xte = tf.placeholder(tf.float32, new TensorShape(784));

// Nearest Neighbor calculation using L1 Distance
// Calculate L1 Distance
var distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices: 1);
// Prediction: Get min distance index (Nearest neighbor)
var pred = tf.arg_min(distance, 0);

float accuracy = 0f;
// Initialize the variables (i.e. assign their default value)
var init = tf.global_variables_initializer();
with(tf.Session(), sess =>
{
// Run the initializer
sess.run(init);

PrepareData();

foreach(int i in range(Xte.shape[0]))
{
// Get nearest neighbor
long nn_index = sess.run(pred, new FeedItem(xtr, Xtr), new FeedItem(xte, Xte[i]));
// Get nearest neighbor class label and compare it to its true label
print($"Test {i} Prediction: {np.argmax(Ytr[nn_index])} True Class: {np.argmax(Yte[i] as NDArray)}");
// Calculate accuracy
if (np.argmax(Ytr[nn_index]) == np.argmax(Yte[i] as NDArray))
accuracy += 1f/ Xte.shape[0];
}

print($"Accuracy: {accuracy}");
});

return accuracy > 0.9;
}

public void PrepareData()
{
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true);
// In this example, we limit mnist data
(Xtr, Ytr) = mnist.train.next_batch(5000); // 5000 for training (nn candidates)
(Xte, Yte) = mnist.test.next_batch(200); // 200 for testing
}
}
}

+ 6
- 6
test/TensorFlowNET.Examples/Program.cs View File

@@ -32,11 +32,11 @@ namespace TensorFlowNET.Examples
{ {
if (example.Enabled) if (example.Enabled)
if (example.Run()) if (example.Run())
success.Add($"{example.Priority} {example.Name}");
success.Add($"Example {example.Priority}: {example.Name}");
else else
errors.Add($"{example.Priority} {example.Name}");
errors.Add($"Example {example.Priority}: {example.Name}");
else else
disabled.Add($"{example.Priority} {example.Name}");
disabled.Add($"Example {example.Priority}: {example.Name}");
} }
catch (Exception ex) catch (Exception ex)
{ {
@@ -46,9 +46,9 @@ namespace TensorFlowNET.Examples
Console.WriteLine($"{DateTime.UtcNow} Completed {example.Name}", Color.White); Console.WriteLine($"{DateTime.UtcNow} Completed {example.Name}", Color.White);
} }


success.ForEach(x => Console.WriteLine($"{x} example is OK!", Color.Green));
disabled.ForEach(x => Console.WriteLine($"{x} example is Disabled!", Color.Tan));
errors.ForEach(x => Console.WriteLine($"{x} example is Failed!", Color.Red));
success.ForEach(x => Console.WriteLine($"{x} is OK!", Color.Green));
disabled.ForEach(x => Console.WriteLine($"{x} is Disabled!", Color.Tan));
errors.ForEach(x => Console.WriteLine($"{x} is Failed!", Color.Red));


Console.ReadLine(); Console.ReadLine();
} }


+ 1
- 1
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -8,7 +8,7 @@
<ItemGroup> <ItemGroup>
<PackageReference Include="Colorful.Console" Version="1.2.9" /> <PackageReference Include="Colorful.Console" Version="1.2.9" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.1" /> <PackageReference Include="Newtonsoft.Json" Version="12.0.1" />
<PackageReference Include="NumSharp" Version="0.8.1" />
<PackageReference Include="NumSharp" Version="0.8.2" />
<PackageReference Include="SharpZipLib" Version="1.1.0" /> <PackageReference Include="SharpZipLib" Version="1.1.0" />
</ItemGroup> </ItemGroup>




+ 1
- 1
test/TensorFlowNET.Examples/TextClassification/TextClassificationWithMovieReviews.cs View File

@@ -11,7 +11,7 @@ namespace TensorFlowNET.Examples
{ {
public class TextClassificationWithMovieReviews : Python, IExample public class TextClassificationWithMovieReviews : Python, IExample
{ {
public int Priority => 6;
public int Priority => 7;
public bool Enabled => false; public bool Enabled => false;
public string Name => "Movie Reviews"; public string Name => "Movie Reviews";




+ 2
- 2
test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj View File

@@ -19,8 +19,8 @@
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.0.1" /> <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.0.1" />
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" />
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> <PackageReference Include="MSTest.TestFramework" Version="1.4.0" />
<PackageReference Include="NumSharp" Version="0.8.1" />
<PackageReference Include="TensorFlow.NET" Version="0.4.2" />
<PackageReference Include="NumSharp" Version="0.8.2" />
<PackageReference Include="TensorFlow.NET" Version="0.5.0" />
</ItemGroup> </ItemGroup>


<ItemGroup> <ItemGroup>


Loading…
Cancel
Save