@@ -10,7 +10,6 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public class _ElementFetchMapper : _FetchMapper | public class _ElementFetchMapper : _FetchMapper | ||||
{ | { | ||||
private List<object> _unique_fetches = new List<object>(); | |||||
private Func<List<object>, object> _contraction_fn; | private Func<List<object>, object> _contraction_fn; | ||||
public _ElementFetchMapper(object[] fetches, Func<List<object>, object> contraction_fn) | public _ElementFetchMapper(object[] fetches, Func<List<object>, object> contraction_fn) | ||||
@@ -32,7 +31,7 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
/// <param name="values"></param> | /// <param name="values"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public NDArray build_results(List<object> values) | |||||
public override NDArray build_results(List<object> values) | |||||
{ | { | ||||
NDArray result = null; | NDArray result = null; | ||||
@@ -51,10 +50,5 @@ namespace Tensorflow | |||||
return result; | return result; | ||||
} | } | ||||
public List<object> unique_fetches() | |||||
{ | |||||
return _unique_fetches; | |||||
} | |||||
} | } | ||||
} | } |
@@ -10,7 +10,7 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public class _FetchHandler | public class _FetchHandler | ||||
{ | { | ||||
private _ElementFetchMapper _fetch_mapper; | |||||
private _FetchMapper _fetch_mapper; | |||||
private List<Tensor> _fetches = new List<Tensor>(); | private List<Tensor> _fetches = new List<Tensor>(); | ||||
private List<bool> _ops = new List<bool>(); | private List<bool> _ops = new List<bool>(); | ||||
private List<Tensor> _final_fetches = new List<Tensor>(); | private List<Tensor> _final_fetches = new List<Tensor>(); | ||||
@@ -18,7 +18,7 @@ namespace Tensorflow | |||||
public _FetchHandler(Graph graph, object fetches, Dictionary<object, object> feeds = null, Action feed_handles = null) | public _FetchHandler(Graph graph, object fetches, Dictionary<object, object> feeds = null, Action feed_handles = null) | ||||
{ | { | ||||
_fetch_mapper = new _FetchMapper().for_fetch(fetches); | |||||
_fetch_mapper = _FetchMapper.for_fetch(fetches); | |||||
foreach(var fetch in _fetch_mapper.unique_fetches()) | foreach(var fetch in _fetch_mapper.unique_fetches()) | ||||
{ | { | ||||
switch (fetch) | switch (fetch) | ||||
@@ -58,7 +58,18 @@ namespace Tensorflow | |||||
{ | { | ||||
var value = tensor_values[j]; | var value = tensor_values[j]; | ||||
j += 1; | j += 1; | ||||
full_values.Add(value); | |||||
switch (value.dtype.Name) | |||||
{ | |||||
case "Int32": | |||||
full_values.Add(value.Data<int>(0)); | |||||
break; | |||||
case "Single": | |||||
full_values.Add(value.Data<float>(0)); | |||||
break; | |||||
case "Double": | |||||
full_values.Add(value.Data<double>(0)); | |||||
break; | |||||
} | |||||
} | } | ||||
i += 1; | i += 1; | ||||
} | } | ||||
@@ -1,4 +1,5 @@ | |||||
using System; | |||||
using NumSharp.Core; | |||||
using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
@@ -6,14 +7,26 @@ namespace Tensorflow | |||||
{ | { | ||||
public class _FetchMapper | public class _FetchMapper | ||||
{ | { | ||||
public _ElementFetchMapper for_fetch(object fetch) | |||||
protected List<object> _unique_fetches = new List<object>(); | |||||
public static _FetchMapper for_fetch(object fetch) | |||||
{ | { | ||||
var fetches = new object[] { fetch }; | |||||
var fetches = fetch.GetType().IsArray ? (object[])fetch : new object[] { fetch }; | |||||
if (fetch.GetType().IsArray) | |||||
return new _ListFetchMapper(fetches); | |||||
else | |||||
return new _ElementFetchMapper(fetches, (List<object> fetched_vals) => fetched_vals[0]); | |||||
} | |||||
return new _ElementFetchMapper(fetches, (List<object> fetched_vals) => | |||||
{ | |||||
return fetched_vals[0]; | |||||
}); | |||||
public virtual NDArray build_results(List<object> values) | |||||
{ | |||||
return values.ToArray(); | |||||
} | |||||
public virtual List<object> unique_fetches() | |||||
{ | |||||
return _unique_fetches; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -0,0 +1,18 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class _ListFetchMapper : _FetchMapper | |||||
{ | |||||
private _FetchMapper[] _mappers; | |||||
public _ListFetchMapper(object[] fetches) | |||||
{ | |||||
_mappers = fetches.Select(fetch => _FetchMapper.for_fetch(fetch)).ToArray(); | |||||
_unique_fetches.AddRange(fetches); | |||||
} | |||||
} | |||||
} |
@@ -40,11 +40,7 @@ namespace TensorFlowNET.Examples | |||||
var pred = tf.nn.softmax(tf.matmul(x, W) + b); // Softmax | var pred = tf.nn.softmax(tf.matmul(x, W) + b); // Softmax | ||||
// Minimize error using cross entropy | // Minimize error using cross entropy | ||||
var log = tf.log(pred); | |||||
var mul = y * log; | |||||
var sum = tf.reduce_sum(mul, reduction_indices: 1); | |||||
var neg = -sum; | |||||
var cost = tf.reduce_mean(neg); | |||||
var cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices: 1)); | |||||
// Gradient Descent | // Gradient Descent | ||||
var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); | var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); | ||||
@@ -68,14 +64,23 @@ namespace TensorFlowNET.Examples | |||||
{ | { | ||||
var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size); | var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size); | ||||
// Run optimization op (backprop) and cost op (to get loss value) | // Run optimization op (backprop) and cost op (to get loss value) | ||||
var (_, c) = sess.run(optimizer, | |||||
var result = sess.run(new object[] { optimizer, cost }, | |||||
new FeedItem(x, batch_xs), | new FeedItem(x, batch_xs), | ||||
new FeedItem(y, batch_ys)); | new FeedItem(y, batch_ys)); | ||||
var c = (float)result[1]; | |||||
// Compute average loss | // Compute average loss | ||||
avg_cost += c / total_batch; | avg_cost += c / total_batch; | ||||
} | } | ||||
// Display logs per epoch step | |||||
if ((epoch + 1) % display_step == 0) | |||||
print($"Epoch: {(epoch + 1).ToString("D4")} cost= {avg_cost.ToString("G9")}"); | |||||
} | } | ||||
print("Optimization Finished!"); | |||||
// Test model | |||||
}); | }); | ||||
} | } | ||||
} | } | ||||
@@ -52,7 +52,28 @@ namespace TensorFlowNET.Examples.Utility | |||||
// Finished epoch | // Finished epoch | ||||
_epochs_completed += 1; | _epochs_completed += 1; | ||||
throw new NotImplementedException("next_batch"); | |||||
// Get the rest examples in this epoch | |||||
var rest_num_examples = _num_examples - start; | |||||
var images_rest_part = _images[np.arange(start, _num_examples)]; | |||||
var labels_rest_part = _labels[np.arange(start, _num_examples)]; | |||||
// Shuffle the data | |||||
if (shuffle) | |||||
{ | |||||
var perm = np.arange(_num_examples); | |||||
np.random.shuffle(perm); | |||||
_images = images[perm]; | |||||
_labels = labels[perm]; | |||||
} | |||||
start = 0; | |||||
_index_in_epoch = batch_size - rest_num_examples; | |||||
var end = _index_in_epoch; | |||||
var images_new_part = _images[np.arange(start, end)]; | |||||
var labels_new_part = _labels[np.arange(start, end)]; | |||||
/*return (np.concatenate(new float[][] { images_rest_part.Data<float>(), images_new_part.Data<float>() }, axis: 0), | |||||
np.concatenate(new float[][] { labels_rest_part.Data<float>(), labels_new_part.Data<float>() }, axis: 0));*/ | |||||
return (images_new_part, labels_new_part); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||