diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index 4380a7fa..d47f9732 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -22,8 +22,10 @@ namespace Tensorflow.Gradients var sy = array_ops.shape(y); var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); - var r1 = gen_array_ops.reshape(math_ops.reduce_sum(grad, rx), sx); - var r2 = gen_array_ops.reshape(math_ops.reduce_sum(grad, ry), sy); + var sum1 = math_ops.reduce_sum(grad, rx); + var r1 = gen_array_ops.reshape(sum1, sx); + var sum2 = math_ops.reduce_sum(grad, ry); + var r2 = gen_array_ops.reshape(sum2, sy); return new Tensor[] { r1, r2 }; } @@ -48,7 +50,8 @@ namespace Tensorflow.Gradients var x = op.inputs[0]; var y = op.inputs[1]; var grad = grads[0]; - if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad) && + if (grad is Tensor && + _ShapesFullySpecifiedAndEqual(x, y, grad) && new TF_DataType[] { tf.int32, tf.float32 }.Contains(grad.dtype)) return new Tensor[] { gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x) }; @@ -60,10 +63,11 @@ namespace Tensorflow.Gradients y = math_ops.conj(y); var mul1 = gen_math_ops.mul(grad, y); - var mul2 = gen_math_ops.mul(x, grad); var reduce_sum1 = math_ops.reduce_sum(mul1, rx); - var reduce_sum2 = math_ops.reduce_sum(mul2, ry); var reshape1 = gen_array_ops.reshape(reduce_sum1, sx); + + var mul2 = gen_math_ops.mul(x, grad); + var reduce_sum2 = math_ops.reduce_sum(mul2, ry); var reshape2 = gen_array_ops.reshape(reduce_sum2, sy); return new Tensor[] { reshape1, reshape2 }; @@ -146,7 +150,13 @@ namespace Tensorflow.Gradients public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad) { - return x.NDims == y.NDims && y.NDims == grad.NDims && x.NDims > -1; + var x_shape = x._shape_tuple(); + var y_shape = y._shape_tuple(); + var grad_shape = grad._shape_tuple(); + return Enumerable.SequenceEqual(x_shape, y_shape) && + Enumerable.SequenceEqual(y_shape, grad_shape) && + x.NDims != -1 && + !x_shape.Contains(-1); } public static Tensor[] _SumGrad(Operation op, Tensor[] grads) diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs index b0382553..4888ec88 100644 --- a/src/TensorFlowNET.Core/Python.cs +++ b/src/TensorFlowNET.Core/Python.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; using System.ComponentModel; +using System.Linq; using System.Text; namespace Tensorflow @@ -16,6 +17,11 @@ namespace Tensorflow Console.WriteLine(obj.ToString()); } + protected IEnumerable range(int end) + { + return Enumerable.Range(0, end); + } + public static T New(object args) where T : IPyClass { var instance = Activator.CreateInstance(); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 5d037696..bf85284a 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -43,6 +43,8 @@ namespace Tensorflow public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); + private TF_Output? _tf_output; + public long[] shape { get @@ -123,7 +125,10 @@ namespace Tensorflow public TF_Output _as_tf_output() { - return new TF_Output(op, value_index); + if(!_tf_output.HasValue) + _tf_output = new TF_Output(op, value_index); + + return _tf_output.Value; } public T[] Data() diff --git a/test/TensorFlowNET.Examples/LogisticRegression.cs b/test/TensorFlowNET.Examples/LogisticRegression.cs index 6cee6d8b..4885f633 100644 --- a/test/TensorFlowNET.Examples/LogisticRegression.cs +++ b/test/TensorFlowNET.Examples/LogisticRegression.cs @@ -1,6 +1,8 @@ -using NumSharp.Core; +using Newtonsoft.Json; +using NumSharp.Core; using System; using System.Collections.Generic; +using System.Linq; using System.Text; using Tensorflow; using TensorFlowNET.Examples.Utility; @@ -26,8 +28,6 @@ namespace TensorFlowNET.Examples private void PrepareData() { - //var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true); - // tf Graph Input var x = tf.placeholder(tf.float32, new TensorShape(-1, 784)); // mnist data image of shape 28*28=784 var y = tf.placeholder(tf.float32, new TensorShape(-1, 10)); // 0-9 digits recognition => 10 classes @@ -49,13 +49,37 @@ namespace TensorFlowNET.Examples // Gradient Descent var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); + //var new_saver = tf.train.import_meta_graph("logistic_regression.meta.bin"); + + /*var text = JsonConvert.SerializeObject(tf.get_default_graph(), new JsonSerializerSettings + { + Formatting = Formatting.Indented + });*/ + // Initialize the variables (i.e. assign their default value) var init = tf.global_variables_initializer(); with(tf.Session(), sess => { + var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true); // Run the initializer sess.run(init); + + // Training cycle + foreach(var epoch in range(training_epochs)) + { + var avg_cost = 0.0f; + var total_batch = (int)(mnist.train.num_examples / batch_size); + // Loop over all batches + foreach (var i in range(total_batch)) + { + var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size); + // Run optimization op (backprop) and cost op (to get loss value) + /*sess.run(optimizer, + new FeedItem(x, batch_xs), + new FeedItem(y, batch_ys));*/ + } + } }); } } diff --git a/test/TensorFlowNET.Examples/Utility/DataSet.cs b/test/TensorFlowNET.Examples/Utility/DataSet.cs index 7ace7b94..0552905f 100644 --- a/test/TensorFlowNET.Examples/Utility/DataSet.cs +++ b/test/TensorFlowNET.Examples/Utility/DataSet.cs @@ -9,10 +9,15 @@ namespace TensorFlowNET.Examples.Utility public class DataSet { private int _num_examples; + public int num_examples => _num_examples; private int _epochs_completed; + public int epochs_completed => _epochs_completed; private int _index_in_epoch; + public int index_in_epoch => _index_in_epoch; private NDArray _images; + public NDArray images => _images; private NDArray _labels; + public NDArray labels => _labels; public DataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape) { @@ -26,5 +31,33 @@ namespace TensorFlowNET.Examples.Utility _epochs_completed = 0; _index_in_epoch = 0; } + + public (int, int) next_batch(int batch_size, bool fake_data = false, bool shuffle = true) + { + var start = _index_in_epoch; + // Shuffle for the first epoch + if(_epochs_completed == 0 && start == 0 && shuffle) + { + var perm0 = np.arange(_num_examples); + np.random.shuffle(perm0); + _images = images[perm0]; + _labels = labels[perm0]; + } + + // Go to the next epoch + if (start + batch_size > _num_examples) + { + // Finished epoch + _epochs_completed += 1; + + throw new NotImplementedException("next_batch"); + } + else + { + _index_in_epoch += batch_size; + var end = _index_in_epoch; + return (_images[np.arange(start, end)], _labels[np.arange(start, end)]); + } + } } }