@@ -1,5 +1,5 @@ | |||||
# TensorFlow.NET | # TensorFlow.NET | ||||
TensorFlow.NET provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. | |||||
TensorFlow.NET (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. | |||||
[](https://gitter.im/sci-sharp/community) | [](https://gitter.im/sci-sharp/community) | ||||
[](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net) | [](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net) | ||||
@@ -8,7 +8,7 @@ TensorFlow.NET provides a .NET Standard binding for [TensorFlow](https://www.ten | |||||
[](https://tensorflownet.readthedocs.io/en/latest/?badge=latest) | [](https://tensorflownet.readthedocs.io/en/latest/?badge=latest) | ||||
[](https://996.icu/#/en_US) | [](https://996.icu/#/en_US) | ||||
TensorFlow.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). | |||||
TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). | |||||
 |  | ||||
@@ -24,14 +24,14 @@ In comparison to other projects, like for instance TensorFlowSharp which only pr | |||||
### How to use | ### How to use | ||||
Install TensorFlow.NET through NuGet. | |||||
Install TF.NET through NuGet. | |||||
```sh | ```sh | ||||
PM> Install-Package TensorFlow.NET | PM> Install-Package TensorFlow.NET | ||||
``` | ``` | ||||
If you are using Linux or Mac OS, please download the pre-compiled dll [here](tensorflowlib) and place it in the working folder. This is only need for Linux and Mac OS, and already packed into NuGet for Windows. | If you are using Linux or Mac OS, please download the pre-compiled dll [here](tensorflowlib) and place it in the working folder. This is only need for Linux and Mac OS, and already packed into NuGet for Windows. | ||||
Import tensorflow.net. | |||||
Import TF.NET. | |||||
```cs | ```cs | ||||
using Tensorflow; | using Tensorflow; | ||||
@@ -10,9 +10,9 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public class _ElementFetchMapper : _FetchMapper | public class _ElementFetchMapper : _FetchMapper | ||||
{ | { | ||||
private Func<List<object>, object> _contraction_fn; | |||||
private Func<List<NDArray>, object> _contraction_fn; | |||||
public _ElementFetchMapper(object[] fetches, Func<List<object>, object> contraction_fn) | |||||
public _ElementFetchMapper(object[] fetches, Func<List<NDArray>, object> contraction_fn) | |||||
{ | { | ||||
var g = ops.get_default_graph(); | var g = ops.get_default_graph(); | ||||
ITensorOrOperation el = null; | ITensorOrOperation el = null; | ||||
@@ -31,7 +31,7 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
/// <param name="values"></param> | /// <param name="values"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public override NDArray build_results(List<object> values) | |||||
public override NDArray build_results(List<NDArray> values) | |||||
{ | { | ||||
NDArray result = null; | NDArray result = null; | ||||
@@ -42,7 +42,7 @@ namespace Tensorflow | |||||
public NDArray build_results(BaseSession session, NDArray[] tensor_values) | public NDArray build_results(BaseSession session, NDArray[] tensor_values) | ||||
{ | { | ||||
var full_values = new List<object>(); | |||||
var full_values = new List<NDArray>(); | |||||
if (_final_fetches.Count != tensor_values.Length) | if (_final_fetches.Count != tensor_values.Length) | ||||
throw new InvalidOperationException("_final_fetches mismatch tensor_values"); | throw new InvalidOperationException("_final_fetches mismatch tensor_values"); | ||||
@@ -17,21 +17,14 @@ namespace Tensorflow | |||||
if (fetch.GetType().IsArray) | if (fetch.GetType().IsArray) | ||||
return new _ListFetchMapper(fetches); | return new _ListFetchMapper(fetches); | ||||
else | else | ||||
return new _ElementFetchMapper(fetches, (List<object> fetched_vals) => fetched_vals[0]); | |||||
return new _ElementFetchMapper(fetches, (List<NDArray> fetched_vals) => fetched_vals[0]); | |||||
} | } | ||||
public virtual NDArray build_results(List<object> values) | |||||
public virtual NDArray build_results(List<NDArray> values) | |||||
{ | { | ||||
var type = values[0].GetType(); | var type = values[0].GetType(); | ||||
var nd = new NDArray(type, values.Count); | var nd = new NDArray(type, values.Count); | ||||
switch (type.Name) | |||||
{ | |||||
case "Single": | |||||
nd.SetData(values.Select(x => (float)x).ToArray()); | |||||
break; | |||||
} | |||||
nd.SetData(values.ToArray()); | |||||
return nd; | return nd; | ||||
} | } | ||||
@@ -62,10 +62,23 @@ namespace TensorFlowNET.Examples | |||||
sess.run(init_op, new FeedItem(X, full_data_x)); | sess.run(init_op, new FeedItem(X, full_data_x)); | ||||
// Training | // Training | ||||
NDArray result = null; | |||||
foreach(var i in range(1, num_steps + 1)) | foreach(var i in range(1, num_steps + 1)) | ||||
{ | { | ||||
var result = sess.run(new Tensor[] { avg_distance, cluster_idx }, new FeedItem(X, full_data_x)); | |||||
result = sess.run(new ITensorOrOperation[] { train_op, avg_distance, cluster_idx }, new FeedItem(X, full_data_x)); | |||||
if (i % 2 == 0 || i == 1) | |||||
print($"Step {i}, Avg Distance: {result[1]}"); | |||||
} | } | ||||
var idx = result[2]; | |||||
// Assign a label to each centroid | |||||
// Count total number of labels per centroid, using the label of each training | |||||
// sample to their closest centroid (given by 'idx') | |||||
var counts = np.zeros(k, num_classes); | |||||
foreach (var i in range(idx.len)) | |||||
counts[idx[i]] += mnist.train.labels[i]; | |||||
}); | }); | ||||
return false; | return false; | ||||