Browse Source

unified the incoming parameter of the _FetchMapper as NDArray.

tags/v0.9
Oceania2018 6 years ago
parent
commit
639d4b0fd3
5 changed files with 25 additions and 19 deletions
  1. +4
    -4
      README.md
  2. +3
    -3
      src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  4. +3
    -10
      src/TensorFlowNET.Core/Sessions/_FetchMapper.cs
  5. +14
    -1
      test/TensorFlowNET.Examples/KMeansClustering.cs

+ 4
- 4
README.md View File

@@ -1,5 +1,5 @@
# TensorFlow.NET
TensorFlow.NET provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework.
TensorFlow.NET (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework.

[![Join the chat at https://gitter.im/publiclab/publiclab](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/sci-sharp/community)
[![Tensorflow.NET](https://ci.appveyor.com/api/projects/status/wx4td43v2d3f2xj6?svg=true)](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net)
@@ -8,7 +8,7 @@ TensorFlow.NET provides a .NET Standard binding for [TensorFlow](https://www.ten
[![Documentation Status](https://readthedocs.org/projects/tensorflownet/badge/?version=latest)](https://tensorflownet.readthedocs.io/en/latest/?badge=latest)
[![Badge](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu/#/en_US)

TensorFlow.NET is a member project of [SciSharp STACK](https://github.com/SciSharp).
TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp).

![tensors_flowing](docs/assets/tensors_flowing.gif)

@@ -24,14 +24,14 @@ In comparison to other projects, like for instance TensorFlowSharp which only pr

### How to use

Install TensorFlow.NET through NuGet.
Install TF.NET through NuGet.
```sh
PM> Install-Package TensorFlow.NET
```

If you are using Linux or Mac OS, please download the pre-compiled dll [here](tensorflowlib) and place it in the working folder. This is only need for Linux and Mac OS, and already packed into NuGet for Windows.

Import tensorflow.net.
Import TF.NET.

```cs
using Tensorflow;


+ 3
- 3
src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs View File

@@ -10,9 +10,9 @@ namespace Tensorflow
/// </summary>
public class _ElementFetchMapper : _FetchMapper
{
private Func<List<object>, object> _contraction_fn;
private Func<List<NDArray>, object> _contraction_fn;

public _ElementFetchMapper(object[] fetches, Func<List<object>, object> contraction_fn)
public _ElementFetchMapper(object[] fetches, Func<List<NDArray>, object> contraction_fn)
{
var g = ops.get_default_graph();
ITensorOrOperation el = null;
@@ -31,7 +31,7 @@ namespace Tensorflow
/// </summary>
/// <param name="values"></param>
/// <returns></returns>
public override NDArray build_results(List<object> values)
public override NDArray build_results(List<NDArray> values)
{
NDArray result = null;



+ 1
- 1
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow

public NDArray build_results(BaseSession session, NDArray[] tensor_values)
{
var full_values = new List<object>();
var full_values = new List<NDArray>();
if (_final_fetches.Count != tensor_values.Length)
throw new InvalidOperationException("_final_fetches mismatch tensor_values");



+ 3
- 10
src/TensorFlowNET.Core/Sessions/_FetchMapper.cs View File

@@ -17,21 +17,14 @@ namespace Tensorflow
if (fetch.GetType().IsArray)
return new _ListFetchMapper(fetches);
else
return new _ElementFetchMapper(fetches, (List<object> fetched_vals) => fetched_vals[0]);
return new _ElementFetchMapper(fetches, (List<NDArray> fetched_vals) => fetched_vals[0]);
}

public virtual NDArray build_results(List<object> values)
public virtual NDArray build_results(List<NDArray> values)
{
var type = values[0].GetType();
var nd = new NDArray(type, values.Count);

switch (type.Name)
{
case "Single":
nd.SetData(values.Select(x => (float)x).ToArray());
break;
}

nd.SetData(values.ToArray());
return nd;
}



+ 14
- 1
test/TensorFlowNET.Examples/KMeansClustering.cs View File

@@ -62,10 +62,23 @@ namespace TensorFlowNET.Examples
sess.run(init_op, new FeedItem(X, full_data_x));

// Training
NDArray result = null;
foreach(var i in range(1, num_steps + 1))
{
var result = sess.run(new Tensor[] { avg_distance, cluster_idx }, new FeedItem(X, full_data_x));
result = sess.run(new ITensorOrOperation[] { train_op, avg_distance, cluster_idx }, new FeedItem(X, full_data_x));
if (i % 2 == 0 || i == 1)
print($"Step {i}, Avg Distance: {result[1]}");
}

var idx = result[2];

// Assign a label to each centroid
// Count total number of labels per centroid, using the label of each training
// sample to their closest centroid (given by 'idx')
var counts = np.zeros(k, num_classes);
foreach (var i in range(idx.len))
counts[idx[i]] += mnist.train.labels[i];

});

return false;


Loading…
Cancel
Save