From 9c161b14dc06df1428161ef593bae1d6e2204be4 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Fri, 22 Mar 2019 01:12:06 -0500 Subject: [PATCH] _ListFetchMapper for multiple fetch in Operation and Tensor. --- .../Sessions/_ElementFetchMapper.cs | 8 +----- .../Sessions/_FetchHandler.cs | 17 +++++++++--- .../Sessions/_FetchMapper.cs | 27 ++++++++++++++----- .../Sessions/_ListFetchMapper.cs | 18 +++++++++++++ .../LogisticRegression.cs | 17 +++++++----- .../TensorFlowNET.Examples/Utility/DataSet.cs | 23 +++++++++++++++- 6 files changed, 86 insertions(+), 24 deletions(-) create mode 100644 src/TensorFlowNET.Core/Sessions/_ListFetchMapper.cs diff --git a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs index cec214a4..c0de60ee 100644 --- a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs @@ -10,7 +10,6 @@ namespace Tensorflow /// public class _ElementFetchMapper : _FetchMapper { - private List _unique_fetches = new List(); private Func, object> _contraction_fn; public _ElementFetchMapper(object[] fetches, Func, object> contraction_fn) @@ -32,7 +31,7 @@ namespace Tensorflow /// /// /// - public NDArray build_results(List values) + public override NDArray build_results(List values) { NDArray result = null; @@ -51,10 +50,5 @@ namespace Tensorflow return result; } - - public List unique_fetches() - { - return _unique_fetches; - } } } diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index e45e3823..01231597 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -10,7 +10,7 @@ namespace Tensorflow /// public class _FetchHandler { - private _ElementFetchMapper _fetch_mapper; + private _FetchMapper _fetch_mapper; private List _fetches = new List(); private List _ops = new List(); private List _final_fetches = new List(); @@ -18,7 +18,7 @@ namespace Tensorflow public _FetchHandler(Graph graph, object fetches, Dictionary feeds = null, Action feed_handles = null) { - _fetch_mapper = new _FetchMapper().for_fetch(fetches); + _fetch_mapper = _FetchMapper.for_fetch(fetches); foreach(var fetch in _fetch_mapper.unique_fetches()) { switch (fetch) @@ -58,7 +58,18 @@ namespace Tensorflow { var value = tensor_values[j]; j += 1; - full_values.Add(value); + switch (value.dtype.Name) + { + case "Int32": + full_values.Add(value.Data(0)); + break; + case "Single": + full_values.Add(value.Data(0)); + break; + case "Double": + full_values.Add(value.Data(0)); + break; + } } i += 1; } diff --git a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs index fc9d6b43..038e9971 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs @@ -1,4 +1,5 @@ -using System; +using NumSharp.Core; +using System; using System.Collections.Generic; using System.Text; @@ -6,14 +7,26 @@ namespace Tensorflow { public class _FetchMapper { - public _ElementFetchMapper for_fetch(object fetch) + protected List _unique_fetches = new List(); + + public static _FetchMapper for_fetch(object fetch) { - var fetches = new object[] { fetch }; + var fetches = fetch.GetType().IsArray ? (object[])fetch : new object[] { fetch }; + + if (fetch.GetType().IsArray) + return new _ListFetchMapper(fetches); + else + return new _ElementFetchMapper(fetches, (List fetched_vals) => fetched_vals[0]); + } - return new _ElementFetchMapper(fetches, (List fetched_vals) => - { - return fetched_vals[0]; - }); + public virtual NDArray build_results(List values) + { + return values.ToArray(); + } + + public virtual List unique_fetches() + { + return _unique_fetches; } } } diff --git a/src/TensorFlowNET.Core/Sessions/_ListFetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_ListFetchMapper.cs new file mode 100644 index 00000000..f94a19da --- /dev/null +++ b/src/TensorFlowNET.Core/Sessions/_ListFetchMapper.cs @@ -0,0 +1,18 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow +{ + public class _ListFetchMapper : _FetchMapper + { + private _FetchMapper[] _mappers; + + public _ListFetchMapper(object[] fetches) + { + _mappers = fetches.Select(fetch => _FetchMapper.for_fetch(fetch)).ToArray(); + _unique_fetches.AddRange(fetches); + } + } +} diff --git a/test/TensorFlowNET.Examples/LogisticRegression.cs b/test/TensorFlowNET.Examples/LogisticRegression.cs index 9920c274..38c124cc 100644 --- a/test/TensorFlowNET.Examples/LogisticRegression.cs +++ b/test/TensorFlowNET.Examples/LogisticRegression.cs @@ -40,11 +40,7 @@ namespace TensorFlowNET.Examples var pred = tf.nn.softmax(tf.matmul(x, W) + b); // Softmax // Minimize error using cross entropy - var log = tf.log(pred); - var mul = y * log; - var sum = tf.reduce_sum(mul, reduction_indices: 1); - var neg = -sum; - var cost = tf.reduce_mean(neg); + var cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices: 1)); // Gradient Descent var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); @@ -68,14 +64,23 @@ namespace TensorFlowNET.Examples { var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size); // Run optimization op (backprop) and cost op (to get loss value) - var (_, c) = sess.run(optimizer, + var result = sess.run(new object[] { optimizer, cost }, new FeedItem(x, batch_xs), new FeedItem(y, batch_ys)); + var c = (float)result[1]; // Compute average loss avg_cost += c / total_batch; } + + // Display logs per epoch step + if ((epoch + 1) % display_step == 0) + print($"Epoch: {(epoch + 1).ToString("D4")} cost= {avg_cost.ToString("G9")}"); } + + print("Optimization Finished!"); + + // Test model }); } } diff --git a/test/TensorFlowNET.Examples/Utility/DataSet.cs b/test/TensorFlowNET.Examples/Utility/DataSet.cs index bd7b0f79..59b86a63 100644 --- a/test/TensorFlowNET.Examples/Utility/DataSet.cs +++ b/test/TensorFlowNET.Examples/Utility/DataSet.cs @@ -52,7 +52,28 @@ namespace TensorFlowNET.Examples.Utility // Finished epoch _epochs_completed += 1; - throw new NotImplementedException("next_batch"); + // Get the rest examples in this epoch + var rest_num_examples = _num_examples - start; + var images_rest_part = _images[np.arange(start, _num_examples)]; + var labels_rest_part = _labels[np.arange(start, _num_examples)]; + // Shuffle the data + if (shuffle) + { + var perm = np.arange(_num_examples); + np.random.shuffle(perm); + _images = images[perm]; + _labels = labels[perm]; + } + + start = 0; + _index_in_epoch = batch_size - rest_num_examples; + var end = _index_in_epoch; + var images_new_part = _images[np.arange(start, end)]; + var labels_new_part = _labels[np.arange(start, end)]; + + /*return (np.concatenate(new float[][] { images_rest_part.Data(), images_new_part.Data() }, axis: 0), + np.concatenate(new float[][] { labels_rest_part.Data(), labels_new_part.Data() }, axis: 0));*/ + return (images_new_part, labels_new_part); } else {