Browse Source

unified the incoming parameter of the _FetchMapper as NDArray.

tags/v0.9
Oceania2018 6 years ago
parent
commit
639d4b0fd3
5 changed files with 25 additions and 19 deletions
  1. +4
    -4
      README.md
  2. +3
    -3
      src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  4. +3
    -10
      src/TensorFlowNET.Core/Sessions/_FetchMapper.cs
  5. +14
    -1
      test/TensorFlowNET.Examples/KMeansClustering.cs

+ 4
- 4
README.md View File

@@ -1,5 +1,5 @@
# TensorFlow.NET # TensorFlow.NET
TensorFlow.NET provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework.
TensorFlow.NET (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework.


[![Join the chat at https://gitter.im/publiclab/publiclab](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/sci-sharp/community) [![Join the chat at https://gitter.im/publiclab/publiclab](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/sci-sharp/community)
[![Tensorflow.NET](https://ci.appveyor.com/api/projects/status/wx4td43v2d3f2xj6?svg=true)](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net) [![Tensorflow.NET](https://ci.appveyor.com/api/projects/status/wx4td43v2d3f2xj6?svg=true)](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net)
@@ -8,7 +8,7 @@ TensorFlow.NET provides a .NET Standard binding for [TensorFlow](https://www.ten
[![Documentation Status](https://readthedocs.org/projects/tensorflownet/badge/?version=latest)](https://tensorflownet.readthedocs.io/en/latest/?badge=latest) [![Documentation Status](https://readthedocs.org/projects/tensorflownet/badge/?version=latest)](https://tensorflownet.readthedocs.io/en/latest/?badge=latest)
[![Badge](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu/#/en_US) [![Badge](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu/#/en_US)


TensorFlow.NET is a member project of [SciSharp STACK](https://github.com/SciSharp).
TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp).


![tensors_flowing](docs/assets/tensors_flowing.gif) ![tensors_flowing](docs/assets/tensors_flowing.gif)


@@ -24,14 +24,14 @@ In comparison to other projects, like for instance TensorFlowSharp which only pr


### How to use ### How to use


Install TensorFlow.NET through NuGet.
Install TF.NET through NuGet.
```sh ```sh
PM> Install-Package TensorFlow.NET PM> Install-Package TensorFlow.NET
``` ```


If you are using Linux or Mac OS, please download the pre-compiled dll [here](tensorflowlib) and place it in the working folder. This is only need for Linux and Mac OS, and already packed into NuGet for Windows. If you are using Linux or Mac OS, please download the pre-compiled dll [here](tensorflowlib) and place it in the working folder. This is only need for Linux and Mac OS, and already packed into NuGet for Windows.


Import tensorflow.net.
Import TF.NET.


```cs ```cs
using Tensorflow; using Tensorflow;


+ 3
- 3
src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs View File

@@ -10,9 +10,9 @@ namespace Tensorflow
/// </summary> /// </summary>
public class _ElementFetchMapper : _FetchMapper public class _ElementFetchMapper : _FetchMapper
{ {
private Func<List<object>, object> _contraction_fn;
private Func<List<NDArray>, object> _contraction_fn;


public _ElementFetchMapper(object[] fetches, Func<List<object>, object> contraction_fn)
public _ElementFetchMapper(object[] fetches, Func<List<NDArray>, object> contraction_fn)
{ {
var g = ops.get_default_graph(); var g = ops.get_default_graph();
ITensorOrOperation el = null; ITensorOrOperation el = null;
@@ -31,7 +31,7 @@ namespace Tensorflow
/// </summary> /// </summary>
/// <param name="values"></param> /// <param name="values"></param>
/// <returns></returns> /// <returns></returns>
public override NDArray build_results(List<object> values)
public override NDArray build_results(List<NDArray> values)
{ {
NDArray result = null; NDArray result = null;




+ 1
- 1
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow


public NDArray build_results(BaseSession session, NDArray[] tensor_values) public NDArray build_results(BaseSession session, NDArray[] tensor_values)
{ {
var full_values = new List<object>();
var full_values = new List<NDArray>();
if (_final_fetches.Count != tensor_values.Length) if (_final_fetches.Count != tensor_values.Length)
throw new InvalidOperationException("_final_fetches mismatch tensor_values"); throw new InvalidOperationException("_final_fetches mismatch tensor_values");




+ 3
- 10
src/TensorFlowNET.Core/Sessions/_FetchMapper.cs View File

@@ -17,21 +17,14 @@ namespace Tensorflow
if (fetch.GetType().IsArray) if (fetch.GetType().IsArray)
return new _ListFetchMapper(fetches); return new _ListFetchMapper(fetches);
else else
return new _ElementFetchMapper(fetches, (List<object> fetched_vals) => fetched_vals[0]);
return new _ElementFetchMapper(fetches, (List<NDArray> fetched_vals) => fetched_vals[0]);
} }


public virtual NDArray build_results(List<object> values)
public virtual NDArray build_results(List<NDArray> values)
{ {
var type = values[0].GetType(); var type = values[0].GetType();
var nd = new NDArray(type, values.Count); var nd = new NDArray(type, values.Count);

switch (type.Name)
{
case "Single":
nd.SetData(values.Select(x => (float)x).ToArray());
break;
}

nd.SetData(values.ToArray());
return nd; return nd;
} }




+ 14
- 1
test/TensorFlowNET.Examples/KMeansClustering.cs View File

@@ -62,10 +62,23 @@ namespace TensorFlowNET.Examples
sess.run(init_op, new FeedItem(X, full_data_x)); sess.run(init_op, new FeedItem(X, full_data_x));


// Training // Training
NDArray result = null;
foreach(var i in range(1, num_steps + 1)) foreach(var i in range(1, num_steps + 1))
{ {
var result = sess.run(new Tensor[] { avg_distance, cluster_idx }, new FeedItem(X, full_data_x));
result = sess.run(new ITensorOrOperation[] { train_op, avg_distance, cluster_idx }, new FeedItem(X, full_data_x));
if (i % 2 == 0 || i == 1)
print($"Step {i}, Avg Distance: {result[1]}");
} }

var idx = result[2];

// Assign a label to each centroid
// Count total number of labels per centroid, using the label of each training
// sample to their closest centroid (given by 'idx')
var counts = np.zeros(k, num_classes);
foreach (var i in range(idx.len))
counts[idx[i]] += mnist.train.labels[i];

}); });


return false; return false;


Loading…
Cancel
Save