@@ -10,7 +10,6 @@ namespace Tensorflow
     /// </summary>
     public class _ElementFetchMapper : _FetchMapper
     {
-        private List<object> _unique_fetches = new List<object>();
         private Func<List<object>, object> _contraction_fn;

         public _ElementFetchMapper(object[] fetches, Func<List<object>, object> contraction_fn)
@@ -32,7 +31,7 @@ namespace Tensorflow
         /// </summary>
         /// <param name="values"></param>
         /// <returns></returns>
-        public NDArray build_results(List<object> values)
+        public override NDArray build_results(List<object> values)
        {
             NDArray result = null;
@@ -51,10 +50,5 @@ namespace Tensorflow
             return result;
         }
-
-        public List<object> unique_fetches()
-        {
-            return _unique_fetches;
-        }
     }
 }
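A note on the `_ElementFetchMapper` hunks above: the private `_unique_fetches` list and the `unique_fetches()` accessor move into the `_FetchMapper` base class (see that file's diff further down), and `build_results` gains the `override` modifier so that calls made through a base-typed reference dispatch to this implementation. A minimal stand-alone illustration of why the modifier matters now that `_FetchHandler` stores the mapper as a `_FetchMapper` (made-up names, not the library's types):

```csharp
using System;

class BaseMapper
{
    public virtual string BuildResults() => "base";
    public string BuildResultsHiding() => "base";
}

class ElementMapper : BaseMapper
{
    // Overrides take part in dynamic dispatch...
    public override string BuildResults() => "element";
    // ...whereas a same-signature method without `override` merely hides the base one.
    public new string BuildResultsHiding() => "element";
}

class Demo
{
    static void Main()
    {
        BaseMapper m = new ElementMapper();        // base-typed reference, like _FetchHandler._fetch_mapper
        Console.WriteLine(m.BuildResults());       // prints "element"
        Console.WriteLine(m.BuildResultsHiding()); // prints "base"
    }
}
```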
@@ -10,7 +10,7 @@ namespace Tensorflow
     /// </summary>
     public class _FetchHandler
     {
-        private _ElementFetchMapper _fetch_mapper;
+        private _FetchMapper _fetch_mapper;
         private List<Tensor> _fetches = new List<Tensor>();
         private List<bool> _ops = new List<bool>();
         private List<Tensor> _final_fetches = new List<Tensor>();
@@ -18,7 +18,7 @@ namespace Tensorflow
         public _FetchHandler(Graph graph, object fetches, Dictionary<object, object> feeds = null, Action feed_handles = null)
         {
-            _fetch_mapper = new _FetchMapper().for_fetch(fetches);
+            _fetch_mapper = _FetchMapper.for_fetch(fetches);
             foreach(var fetch in _fetch_mapper.unique_fetches())
             {
                 switch (fetch)
@@ -58,7 +58,18 @@ namespace Tensorflow
                 {
                     var value = tensor_values[j];
                     j += 1;
-                    full_values.Add(value);
+                    switch (value.dtype.Name)
+                    {
+                        case "Int32":
+                            full_values.Add(value.Data<int>(0));
+                            break;
+                        case "Single":
+                            full_values.Add(value.Data<float>(0));
+                            break;
+                        case "Double":
+                            full_values.Add(value.Data<double>(0));
+                            break;
+                    }
                 }
                 i += 1;
             }
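In the `_FetchHandler` hunk above, each fetched tensor value is now converted to a managed scalar by switching on `value.dtype.Name`; as the NumSharp calls are used here, `Data<T>(0)` appears to take only the first element, and any dtype other than Int32/Single/Double falls through the switch without adding anything to `full_values`. A small stand-alone sketch of the same pattern with an explicit fallback branch (plain C#, illustrative only, not the library API):

```csharp
using System;
using System.Collections.Generic;

static class ScalarExtraction
{
    // `value` stands in for a fetched result whose element type is inspected at run time.
    static object ExtractScalar(object value)
    {
        switch (value.GetType().Name)
        {
            case "Int32":  return (int)value;
            case "Single": return (float)value;
            case "Double": return (double)value;
            default:       return value;   // keep unrecognized types instead of dropping them
        }
    }

    static void Main()
    {
        var fullValues = new List<object>();
        foreach (var v in new object[] { 1, 2.5f, 3.0, 4L })
            fullValues.Add(ExtractScalar(v));

        Console.WriteLine(string.Join(", ", fullValues));  // 1, 2.5, 3, 4
    }
}
```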
@@ -1,4 +1,5 @@
-using System;
+using NumSharp.Core;
+using System;
 using System.Collections.Generic;
 using System.Text;
@@ -6,14 +7,26 @@ namespace Tensorflow
 {
     public class _FetchMapper
     {
-        public _ElementFetchMapper for_fetch(object fetch)
+        protected List<object> _unique_fetches = new List<object>();
+
+        public static _FetchMapper for_fetch(object fetch)
         {
-            var fetches = new object[] { fetch };
-
-            return new _ElementFetchMapper(fetches, (List<object> fetched_vals) =>
-            {
-                return fetched_vals[0];
-            });
+            var fetches = fetch.GetType().IsArray ? (object[])fetch : new object[] { fetch };
+
+            if (fetch.GetType().IsArray)
+                return new _ListFetchMapper(fetches);
+            else
+                return new _ElementFetchMapper(fetches, (List<object> fetched_vals) => fetched_vals[0]);
         }
+
+        public virtual NDArray build_results(List<object> values)
+        {
+            return values.ToArray();
+        }
+
+        public virtual List<object> unique_fetches()
+        {
+            return _unique_fetches;
+        }
     }
 }
@@ -0,0 +1,18 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Tensorflow
+{
+    public class _ListFetchMapper : _FetchMapper
+    {
+        private _FetchMapper[] _mappers;
+
+        public _ListFetchMapper(object[] fetches)
+        {
+            _mappers = fetches.Select(fetch => _FetchMapper.for_fetch(fetch)).ToArray();
+            _unique_fetches.AddRange(fetches);
+        }
+    }
+}
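Taken together, the `_FetchMapper` changes and the new `_ListFetchMapper` give `for_fetch` a simple dispatch rule: an array of fetches produces a list mapper, which recursively builds one child mapper per element, while any other fetch produces an element mapper. A condensed, self-contained sketch of that shape (simplified names, not the actual TensorFlow.NET classes):

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

abstract class MapperSketch
{
    protected readonly List<object> _uniqueFetches = new List<object>();
    public List<object> UniqueFetches() => _uniqueFetches;

    // Mirrors the static for_fetch factory: arrays -> list mapper, otherwise element mapper.
    public static MapperSketch ForFetch(object fetch) =>
        fetch is object[] array
            ? (MapperSketch)new ListMapperSketch(array)
            : new ElementMapperSketch(new[] { fetch }, vals => vals[0]);
}

class ElementMapperSketch : MapperSketch
{
    private readonly Func<List<object>, object> _contraction;
    public ElementMapperSketch(object[] fetches, Func<List<object>, object> contraction)
    {
        _uniqueFetches.AddRange(fetches);
        _contraction = contraction;
    }
}

class ListMapperSketch : MapperSketch
{
    private readonly MapperSketch[] _mappers;
    public ListMapperSketch(object[] fetches)
    {
        _mappers = fetches.Select(ForFetch).ToArray();  // one child mapper per element
        _uniqueFetches.AddRange(fetches);               // the raw fetches, not a de-duplicated union
    }
}

class Usage
{
    static void Main()
    {
        var single = MapperSketch.ForFetch("cost");                        // ElementMapperSketch
        var list   = MapperSketch.ForFetch(new object[] { "op", "cost" }); // ListMapperSketch
        Console.WriteLine(single.GetType().Name + " / " + list.GetType().Name);
    }
}
```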
@@ -40,11 +40,7 @@ namespace TensorFlowNET.Examples
             var pred = tf.nn.softmax(tf.matmul(x, W) + b); // Softmax

             // Minimize error using cross entropy
-            var log = tf.log(pred);
-            var mul = y * log;
-            var sum = tf.reduce_sum(mul, reduction_indices: 1);
-            var neg = -sum;
-            var cost = tf.reduce_mean(neg);
+            var cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices: 1));

             // Gradient Descent
             var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);
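The consolidated `cost` line above is the usual multi-class cross-entropy: the inner `reduce_sum` with `reduction_indices: 1` sums over the class dimension, and `reduce_mean` averages over the batch. Written out for batch size m, one-hot labels y, and softmax outputs y-hat:

```latex
J(W, b) = \frac{1}{m} \sum_{i=1}^{m} \left( - \sum_{j} y_{ij} \, \log \hat{y}_{ij} \right)
```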
@@ -68,14 +64,23 @@ namespace TensorFlowNET.Examples
                {
                    var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size);
                    // Run optimization op (backprop) and cost op (to get loss value)
-                   var (_, c) = sess.run(optimizer,
+                   var result = sess.run(new object[] { optimizer, cost },
                        new FeedItem(x, batch_xs),
                        new FeedItem(y, batch_ys));
+                   var c = (float)result[1];
                    // Compute average loss
                    avg_cost += c / total_batch;
                }
+               // Display logs per epoch step
+               if ((epoch + 1) % display_step == 0)
+                   print($"Epoch: {(epoch + 1).ToString("D4")} cost= {avg_cost.ToString("G9")}");
            }
+           print("Optimization Finished!");
+           // Test model
        });
    }
 }
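One detail of the training-loop change above: with the fetch-mapper refactor, `sess.run` accepts an `object[]` of fetches, and the example reads `result[1]` as the batch cost, matching the order of the fetches in the array, while `result[0]` corresponds to the optimizer op; the `float` cast relies on the scalar extraction added to `_FetchHandler` earlier in this change. Restated as used in this example (the TensorFlow.NET objects `sess`, `optimizer`, `cost`, `x`, `y`, `batch_xs`, and `batch_ys` are assumed to be in scope):

```csharp
// Fetch the optimizer op and the cost tensor in a single run call; the example
// treats the returned array as being in fetch order, so index 1 is the batch cost.
var result = sess.run(new object[] { optimizer, cost },
                      new FeedItem(x, batch_xs),
                      new FeedItem(y, batch_ys));
var batchCost = (float)result[1];  // result[0] belongs to the optimizer op
```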
@@ -52,7 +52,28 @@ namespace TensorFlowNET.Examples.Utility
                // Finished epoch
                _epochs_completed += 1;
-               throw new NotImplementedException("next_batch");
+               // Get the rest examples in this epoch
+               var rest_num_examples = _num_examples - start;
+               var images_rest_part = _images[np.arange(start, _num_examples)];
+               var labels_rest_part = _labels[np.arange(start, _num_examples)];
+               // Shuffle the data
+               if (shuffle)
+               {
+                   var perm = np.arange(_num_examples);
+                   np.random.shuffle(perm);
+                   _images = images[perm];
+                   _labels = labels[perm];
+               }
+               start = 0;
+               _index_in_epoch = batch_size - rest_num_examples;
+               var end = _index_in_epoch;
+               var images_new_part = _images[np.arange(start, end)];
+               var labels_new_part = _labels[np.arange(start, end)];
+               /*return (np.concatenate(new float[][] { images_rest_part.Data<float>(), images_new_part.Data<float>() }, axis: 0),
+                   np.concatenate(new float[][] { labels_rest_part.Data<float>(), labels_new_part.Data<float>() }, axis: 0));*/
+               return (images_new_part, labels_new_part);
            }
            else
            {