
_ShapesFullySpecifiedAndEqual

tags/v0.8.0
Oceania2018 6 years ago
commit 369d8abd7a
5 changed files with 88 additions and 10 deletions
  1. +16 -6   src/TensorFlowNET.Core/Gradients/math_grad.cs
  2. +6  -0   src/TensorFlowNET.Core/Python.cs
  3. +6  -1   src/TensorFlowNET.Core/Tensors/Tensor.cs
  4. +27 -3   test/TensorFlowNET.Examples/LogisticRegression.cs
  5. +33 -0   test/TensorFlowNET.Examples/Utility/DataSet.cs

+16 -6  src/TensorFlowNET.Core/Gradients/math_grad.cs

@@ -22,8 +22,10 @@ namespace Tensorflow.Gradients
var sy = array_ops.shape(y);
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);

-var r1 = gen_array_ops.reshape(math_ops.reduce_sum(grad, rx), sx);
-var r2 = gen_array_ops.reshape(math_ops.reduce_sum(grad, ry), sy);
+var sum1 = math_ops.reduce_sum(grad, rx);
+var r1 = gen_array_ops.reshape(sum1, sx);
+var sum2 = math_ops.reduce_sum(grad, ry);
+var r2 = gen_array_ops.reshape(sum2, sy);

return new Tensor[] { r1, r2 };
}
@@ -48,7 +50,8 @@ namespace Tensorflow.Gradients
var x = op.inputs[0];
var y = op.inputs[1];
var grad = grads[0];
-if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad) &&
+if (grad is Tensor &&
+_ShapesFullySpecifiedAndEqual(x, y, grad) &&
new TF_DataType[] { tf.int32, tf.float32 }.Contains(grad.dtype))
return new Tensor[] { gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x) };

@@ -60,10 +63,11 @@ namespace Tensorflow.Gradients
y = math_ops.conj(y);

var mul1 = gen_math_ops.mul(grad, y);
-var mul2 = gen_math_ops.mul(x, grad);
var reduce_sum1 = math_ops.reduce_sum(mul1, rx);
-var reduce_sum2 = math_ops.reduce_sum(mul2, ry);
var reshape1 = gen_array_ops.reshape(reduce_sum1, sx);
+
+var mul2 = gen_math_ops.mul(x, grad);
+var reduce_sum2 = math_ops.reduce_sum(mul2, ry);
var reshape2 = gen_array_ops.reshape(reduce_sum2, sy);

return new Tensor[] { reshape1, reshape2 };
@@ -146,7 +150,13 @@ namespace Tensorflow.Gradients

public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad)
{
-return x.NDims == y.NDims && y.NDims == grad.NDims && x.NDims > -1;
+var x_shape = x._shape_tuple();
+var y_shape = y._shape_tuple();
+var grad_shape = grad._shape_tuple();
+return Enumerable.SequenceEqual(x_shape, y_shape) &&
+Enumerable.SequenceEqual(y_shape, grad_shape) &&
+x.NDims != -1 &&
+!x_shape.Contains(-1);
}

public static Tensor[] _SumGrad(Operation op, Tensor[] grads)
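
Note (not part of the commit): the rewritten _ShapesFullySpecifiedAndEqual compares the full static shape tuples instead of only the ranks, and rejects any shape containing an unknown (-1) dimension. A minimal standalone sketch of the same check over plain int[] shapes; the class and method names below are illustrative only:

using System.Linq;

public static class ShapeCheckSketch
{
    // All three shapes must be identical and fully specified (no -1 / unknown dims).
    public static bool ShapesFullySpecifiedAndEqual(int[] x, int[] y, int[] grad)
        => x.SequenceEqual(y) && y.SequenceEqual(grad) && !x.Contains(-1);
}

// ShapesFullySpecifiedAndEqual(new[] { 2, 3 }, new[] { 2, 3 }, new[] { 2, 3 })    -> true
// ShapesFullySpecifiedAndEqual(new[] { 2, -1 }, new[] { 2, -1 }, new[] { 2, -1 }) -> false (unknown dim)
// ShapesFullySpecifiedAndEqual(new[] { 2, 3 }, new[] { 3, 2 }, new[] { 2, 3 })    -> false (shapes differ)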


+6 -0  src/TensorFlowNET.Core/Python.cs

@@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
+using System.Linq;
using System.Text;

namespace Tensorflow
@@ -16,6 +17,11 @@ namespace Tensorflow
Console.WriteLine(obj.ToString());
}

+protected IEnumerable<int> range(int end)
+{
+return Enumerable.Range(0, end);
+}
+
public static T New<T>(object args) where T : IPyClass
{
var instance = Activator.CreateInstance<T>();
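
Note (not part of the commit): the new range helper mirrors Python's built-in range(n), yielding 0 through n - 1, which is what the example loops below rely on. A hedged usage sketch; the class name is illustrative:

using System;
using System.Collections.Generic;
using System.Linq;

class RangeSketch
{
    // Same shape as the helper added above: yields 0, 1, ..., end - 1.
    static IEnumerable<int> range(int end) => Enumerable.Range(0, end);

    static void Main()
    {
        foreach (var i in range(3))
            Console.WriteLine(i);   // prints 0, 1, 2 on separate lines
    }
}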


+6 -1  src/TensorFlowNET.Core/Tensors/Tensor.cs

@@ -43,6 +43,8 @@ namespace Tensorflow
public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);

+private TF_Output? _tf_output;
+
public long[] shape
{
get
@@ -123,7 +125,10 @@ namespace Tensorflow

public TF_Output _as_tf_output()
{
-return new TF_Output(op, value_index);
+if(!_tf_output.HasValue)
+_tf_output = new TF_Output(op, value_index);
+
+return _tf_output.Value;
}

public T[] Data<T>()
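
Note (not part of the commit): _as_tf_output() now caches its TF_Output in a nullable field so the struct is only constructed once per Tensor. A minimal sketch of that lazy-caching pattern with a nullable value type; the types below are placeholders, not the real TF_Output:

struct Output
{
    public readonly int OpId;
    public readonly int Index;
    public Output(int opId, int index) { OpId = opId; Index = index; }
}

class CachedOutputSketch
{
    private Output? _cached;                 // stays null until first request

    public Output AsOutput(int opId, int index)
    {
        if (!_cached.HasValue)               // build once, reuse afterwards
            _cached = new Output(opId, index);
        return _cached.Value;
    }
}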


+27 -3  test/TensorFlowNET.Examples/LogisticRegression.cs

@@ -1,6 +1,8 @@
-using NumSharp.Core;
+using Newtonsoft.Json;
+using NumSharp.Core;
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Text;
using Tensorflow;
using TensorFlowNET.Examples.Utility;
@@ -26,8 +28,6 @@ namespace TensorFlowNET.Examples

private void PrepareData()
{
-//var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
-
// tf Graph Input
var x = tf.placeholder(tf.float32, new TensorShape(-1, 784)); // mnist data image of shape 28*28=784
var y = tf.placeholder(tf.float32, new TensorShape(-1, 10)); // 0-9 digits recognition => 10 classes
@@ -49,13 +49,37 @@ namespace TensorFlowNET.Examples
// Gradient Descent
var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);

+//var new_saver = tf.train.import_meta_graph("logistic_regression.meta.bin");
+
+/*var text = JsonConvert.SerializeObject(tf.get_default_graph(), new JsonSerializerSettings
+{
+Formatting = Formatting.Indented
+});*/
+
// Initialize the variables (i.e. assign their default value)
var init = tf.global_variables_initializer();

with(tf.Session(), sess =>
{
+var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
// Run the initializer
sess.run(init);

+// Training cycle
+foreach(var epoch in range(training_epochs))
+{
+var avg_cost = 0.0f;
+var total_batch = (int)(mnist.train.num_examples / batch_size);
+// Loop over all batches
+foreach (var i in range(total_batch))
+{
+var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size);
+// Run optimization op (backprop) and cost op (to get loss value)
+/*sess.run(optimizer,
+new FeedItem(x, batch_xs),
+new FeedItem(y, batch_ys));*/
+}
+}
});
}
}
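
Note (not part of the commit): the expanded PrepareData loads MNIST inside the session and walks epochs and mini-batches, but the sess.run optimization call is still commented out, so the loop only establishes the control flow. For reference, a small sketch of the avg_cost bookkeeping the loop is shaped for; the batch losses are made-up numbers, and the running-mean pattern follows the upstream Python example rather than anything executed here:

using System;

class AvgCostSketch
{
    static void Main()
    {
        // Hypothetical per-batch losses for one epoch (illustrative only).
        var batchCosts = new[] { 0.9f, 0.7f, 0.5f, 0.3f };
        var total_batch = batchCosts.Length;

        var avg_cost = 0.0f;
        foreach (var c in batchCosts)
            avg_cost += c / total_batch;     // running mean over the epoch

        Console.WriteLine(avg_cost);         // ~0.6
    }
}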


+33 -0  test/TensorFlowNET.Examples/Utility/DataSet.cs

@@ -9,10 +9,15 @@ namespace TensorFlowNET.Examples.Utility
public class DataSet
{
private int _num_examples;
+public int num_examples => _num_examples;
private int _epochs_completed;
+public int epochs_completed => _epochs_completed;
private int _index_in_epoch;
+public int index_in_epoch => _index_in_epoch;
private NDArray _images;
+public NDArray images => _images;
private NDArray _labels;
+public NDArray labels => _labels;

public DataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape)
{
@@ -26,5 +31,33 @@ namespace TensorFlowNET.Examples.Utility
_epochs_completed = 0;
_index_in_epoch = 0;
}

+public (NDArray, NDArray) next_batch(int batch_size, bool fake_data = false, bool shuffle = true)
+{
+var start = _index_in_epoch;
+// Shuffle for the first epoch
+if(_epochs_completed == 0 && start == 0 && shuffle)
+{
+var perm0 = np.arange(_num_examples);
+np.random.shuffle(perm0);
+_images = images[perm0];
+_labels = labels[perm0];
+}
+
+// Go to the next epoch
+if (start + batch_size > _num_examples)
+{
+// Finished epoch
+_epochs_completed += 1;
+
+throw new NotImplementedException("next_batch");
+}
+else
+{
+_index_in_epoch += batch_size;
+var end = _index_in_epoch;
+return (_images[np.arange(start, end)], _labels[np.arange(start, end)]);
+}
+}
}
}
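
Note (not part of the commit): next_batch handles only the in-epoch case and throws once a batch would cross the epoch boundary. In the upstream Python MNIST DataSet that this utility mirrors, that branch takes the remaining examples, reshuffles, and tops the batch up from the start of the next epoch. A hedged sketch of that wrap-around arithmetic on plain indices; the names and the index-only representation are illustrative, not the NumSharp API:

using System;
using System.Linq;

class EpochBoundarySketch
{
    static void Main()
    {
        int num_examples = 10, batch_size = 4, index_in_epoch = 8;

        // 8 + 4 > 10, so this batch crosses the epoch boundary.
        int rest = num_examples - index_in_epoch;             // 2 examples left in this epoch
        var restIdx = Enumerable.Range(index_in_epoch, rest);

        int newPart = batch_size - rest;                      // 2 examples from the reshuffled next epoch
        var newIdx = Enumerable.Range(0, newPart);

        Console.WriteLine(string.Join(",", restIdx.Concat(newIdx)));  // 8,9,0,1
        // _index_in_epoch would restart at newPart (2) and _epochs_completed would increment.
    }
}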
