From 639d4b0fd394deccbeeed19ca78b5a273337f761 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 25 Apr 2019 06:20:01 -0500 Subject: [PATCH] unified the incoming parameter of the _FetchMapper as NDArray. --- README.md | 8 ++++---- .../Sessions/_ElementFetchMapper.cs | 6 +++--- src/TensorFlowNET.Core/Sessions/_FetchHandler.cs | 2 +- src/TensorFlowNET.Core/Sessions/_FetchMapper.cs | 13 +++---------- test/TensorFlowNET.Examples/KMeansClustering.cs | 15 ++++++++++++++- 5 files changed, 25 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 04f29b6e..a2a205fd 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # TensorFlow.NET -TensorFlow.NET provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. +TensorFlow.NET (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. [![Join the chat at https://gitter.im/publiclab/publiclab](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/sci-sharp/community) [![Tensorflow.NET](https://ci.appveyor.com/api/projects/status/wx4td43v2d3f2xj6?svg=true)](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net) @@ -8,7 +8,7 @@ TensorFlow.NET provides a .NET Standard binding for [TensorFlow](https://www.ten [![Documentation Status](https://readthedocs.org/projects/tensorflownet/badge/?version=latest)](https://tensorflownet.readthedocs.io/en/latest/?badge=latest) [![Badge](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu/#/en_US) -TensorFlow.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). +TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). ![tensors_flowing](docs/assets/tensors_flowing.gif) @@ -24,14 +24,14 @@ In comparison to other projects, like for instance TensorFlowSharp which only pr ### How to use -Install TensorFlow.NET through NuGet. +Install TF.NET through NuGet. ```sh PM> Install-Package TensorFlow.NET ``` If you are using Linux or Mac OS, please download the pre-compiled dll [here](tensorflowlib) and place it in the working folder. This is only need for Linux and Mac OS, and already packed into NuGet for Windows. -Import tensorflow.net. +Import TF.NET. ```cs using Tensorflow; diff --git a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs index e06f9f2e..e18eebc0 100644 --- a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs @@ -10,9 +10,9 @@ namespace Tensorflow /// public class _ElementFetchMapper : _FetchMapper { - private Func, object> _contraction_fn; + private Func, object> _contraction_fn; - public _ElementFetchMapper(object[] fetches, Func, object> contraction_fn) + public _ElementFetchMapper(object[] fetches, Func, object> contraction_fn) { var g = ops.get_default_graph(); ITensorOrOperation el = null; @@ -31,7 +31,7 @@ namespace Tensorflow /// /// /// - public override NDArray build_results(List values) + public override NDArray build_results(List values) { NDArray result = null; diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index 9a6a738a..2c5b55f6 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -42,7 +42,7 @@ namespace Tensorflow public NDArray build_results(BaseSession session, NDArray[] tensor_values) { - var full_values = new List(); + var full_values = new List(); if (_final_fetches.Count != tensor_values.Length) throw new InvalidOperationException("_final_fetches mismatch tensor_values"); diff --git a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs index 6d4b3b40..4b2691a8 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs @@ -17,21 +17,14 @@ namespace Tensorflow if (fetch.GetType().IsArray) return new _ListFetchMapper(fetches); else - return new _ElementFetchMapper(fetches, (List fetched_vals) => fetched_vals[0]); + return new _ElementFetchMapper(fetches, (List fetched_vals) => fetched_vals[0]); } - public virtual NDArray build_results(List values) + public virtual NDArray build_results(List values) { var type = values[0].GetType(); var nd = new NDArray(type, values.Count); - - switch (type.Name) - { - case "Single": - nd.SetData(values.Select(x => (float)x).ToArray()); - break; - } - + nd.SetData(values.ToArray()); return nd; } diff --git a/test/TensorFlowNET.Examples/KMeansClustering.cs b/test/TensorFlowNET.Examples/KMeansClustering.cs index 9be1ad0d..d9e2de47 100644 --- a/test/TensorFlowNET.Examples/KMeansClustering.cs +++ b/test/TensorFlowNET.Examples/KMeansClustering.cs @@ -62,10 +62,23 @@ namespace TensorFlowNET.Examples sess.run(init_op, new FeedItem(X, full_data_x)); // Training + NDArray result = null; foreach(var i in range(1, num_steps + 1)) { - var result = sess.run(new Tensor[] { avg_distance, cluster_idx }, new FeedItem(X, full_data_x)); + result = sess.run(new ITensorOrOperation[] { train_op, avg_distance, cluster_idx }, new FeedItem(X, full_data_x)); + if (i % 2 == 0 || i == 1) + print($"Step {i}, Avg Distance: {result[1]}"); } + + var idx = result[2]; + + // Assign a label to each centroid + // Count total number of labels per centroid, using the label of each training + // sample to their closest centroid (given by 'idx') + var counts = np.zeros(k, num_classes); + foreach (var i in range(idx.len)) + counts[idx[i]] += mnist.train.labels[i]; + }); return false;