Browse Source

_ListFetchMapper for multiple fetch in Operation and Tensor.

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
9c161b14dc
6 changed files with 86 additions and 24 deletions
  1. +1
    -7
      src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
  2. +14
    -3
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  3. +20
    -7
      src/TensorFlowNET.Core/Sessions/_FetchMapper.cs
  4. +18
    -0
      src/TensorFlowNET.Core/Sessions/_ListFetchMapper.cs
  5. +11
    -6
      test/TensorFlowNET.Examples/LogisticRegression.cs
  6. +22
    -1
      test/TensorFlowNET.Examples/Utility/DataSet.cs

+ 1
- 7
src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs View File

@@ -10,7 +10,6 @@ namespace Tensorflow
/// </summary>
public class _ElementFetchMapper : _FetchMapper
{
private List<object> _unique_fetches = new List<object>();
private Func<List<object>, object> _contraction_fn;

public _ElementFetchMapper(object[] fetches, Func<List<object>, object> contraction_fn)
@@ -32,7 +31,7 @@ namespace Tensorflow
/// </summary>
/// <param name="values"></param>
/// <returns></returns>
public NDArray build_results(List<object> values)
public override NDArray build_results(List<object> values)
{
NDArray result = null;

@@ -51,10 +50,5 @@ namespace Tensorflow

return result;
}

public List<object> unique_fetches()
{
return _unique_fetches;
}
}
}

+ 14
- 3
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -10,7 +10,7 @@ namespace Tensorflow
/// </summary>
public class _FetchHandler
{
private _ElementFetchMapper _fetch_mapper;
private _FetchMapper _fetch_mapper;
private List<Tensor> _fetches = new List<Tensor>();
private List<bool> _ops = new List<bool>();
private List<Tensor> _final_fetches = new List<Tensor>();
@@ -18,7 +18,7 @@ namespace Tensorflow

public _FetchHandler(Graph graph, object fetches, Dictionary<object, object> feeds = null, Action feed_handles = null)
{
_fetch_mapper = new _FetchMapper().for_fetch(fetches);
_fetch_mapper = _FetchMapper.for_fetch(fetches);
foreach(var fetch in _fetch_mapper.unique_fetches())
{
switch (fetch)
@@ -58,7 +58,18 @@ namespace Tensorflow
{
var value = tensor_values[j];
j += 1;
full_values.Add(value);
switch (value.dtype.Name)
{
case "Int32":
full_values.Add(value.Data<int>(0));
break;
case "Single":
full_values.Add(value.Data<float>(0));
break;
case "Double":
full_values.Add(value.Data<double>(0));
break;
}
}
i += 1;
}


+ 20
- 7
src/TensorFlowNET.Core/Sessions/_FetchMapper.cs View File

@@ -1,4 +1,5 @@
using System;
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Text;

@@ -6,14 +7,26 @@ namespace Tensorflow
{
public class _FetchMapper
{
public _ElementFetchMapper for_fetch(object fetch)
protected List<object> _unique_fetches = new List<object>();

public static _FetchMapper for_fetch(object fetch)
{
var fetches = new object[] { fetch };
var fetches = fetch.GetType().IsArray ? (object[])fetch : new object[] { fetch };

if (fetch.GetType().IsArray)
return new _ListFetchMapper(fetches);
else
return new _ElementFetchMapper(fetches, (List<object> fetched_vals) => fetched_vals[0]);
}

return new _ElementFetchMapper(fetches, (List<object> fetched_vals) =>
{
return fetched_vals[0];
});
public virtual NDArray build_results(List<object> values)
{
return values.ToArray();
}

public virtual List<object> unique_fetches()
{
return _unique_fetches;
}
}
}

+ 18
- 0
src/TensorFlowNET.Core/Sessions/_ListFetchMapper.cs View File

@@ -0,0 +1,18 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow
{
public class _ListFetchMapper : _FetchMapper
{
private _FetchMapper[] _mappers;

public _ListFetchMapper(object[] fetches)
{
_mappers = fetches.Select(fetch => _FetchMapper.for_fetch(fetch)).ToArray();
_unique_fetches.AddRange(fetches);
}
}
}

+ 11
- 6
test/TensorFlowNET.Examples/LogisticRegression.cs View File

@@ -40,11 +40,7 @@ namespace TensorFlowNET.Examples
var pred = tf.nn.softmax(tf.matmul(x, W) + b); // Softmax

// Minimize error using cross entropy
var log = tf.log(pred);
var mul = y * log;
var sum = tf.reduce_sum(mul, reduction_indices: 1);
var neg = -sum;
var cost = tf.reduce_mean(neg);
var cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices: 1));

// Gradient Descent
var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);
@@ -68,14 +64,23 @@ namespace TensorFlowNET.Examples
{
var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size);
// Run optimization op (backprop) and cost op (to get loss value)
var (_, c) = sess.run(optimizer,
var result = sess.run(new object[] { optimizer, cost },
new FeedItem(x, batch_xs),
new FeedItem(y, batch_ys));

var c = (float)result[1];
// Compute average loss
avg_cost += c / total_batch;
}

// Display logs per epoch step
if ((epoch + 1) % display_step == 0)
print($"Epoch: {(epoch + 1).ToString("D4")} cost= {avg_cost.ToString("G9")}");
}

print("Optimization Finished!");

// Test model
});
}
}


+ 22
- 1
test/TensorFlowNET.Examples/Utility/DataSet.cs View File

@@ -52,7 +52,28 @@ namespace TensorFlowNET.Examples.Utility
// Finished epoch
_epochs_completed += 1;

throw new NotImplementedException("next_batch");
// Get the rest examples in this epoch
var rest_num_examples = _num_examples - start;
var images_rest_part = _images[np.arange(start, _num_examples)];
var labels_rest_part = _labels[np.arange(start, _num_examples)];
// Shuffle the data
if (shuffle)
{
var perm = np.arange(_num_examples);
np.random.shuffle(perm);
_images = images[perm];
_labels = labels[perm];
}

start = 0;
_index_in_epoch = batch_size - rest_num_examples;
var end = _index_in_epoch;
var images_new_part = _images[np.arange(start, end)];
var labels_new_part = _labels[np.arange(start, end)];

/*return (np.concatenate(new float[][] { images_rest_part.Data<float>(), images_new_part.Data<float>() }, axis: 0),
np.concatenate(new float[][] { labels_rest_part.Data<float>(), labels_new_part.Data<float>() }, axis: 0));*/
return (images_new_part, labels_new_part);
}
else
{


Loading…
Cancel
Save