Browse Source

tf.reduce_sum #917

tags/v0.70.2-NET6
Oceania2018 3 years ago
parent
commit
427a2f9f1e
3 changed files with 30 additions and 7 deletions
  1. +6
    -1
      src/TensorFlowNET.Core/Gradients/math_grad.cs
  2. +8
    -3
      src/TensorFlowNET.Core/Operations/array_ops.cs
  3. +16
    -3
      test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs

+ 6
- 1
src/TensorFlowNET.Core/Gradients/math_grad.cs View File

@@ -529,7 +529,12 @@ namespace Tensorflow.Gradients
}
else if (!input_0_shape.Contains(-1) && !tf.Context.executing_eagerly())
{
throw new NotImplementedException("");
axes = axes.reshape(new Shape(-1));
var shape_tensor = tf.constant(op.inputs[0].shape.as_int_list());
var output_shape_kept_dims = math_ops.reduced_shape(shape_tensor, axes);
var tile_scaling = _safe_shape_div(shape_tensor, output_shape_kept_dims);
grad = array_ops.reshape(grad, output_shape_kept_dims);
return new Tensor[] { array_ops.tile(grad, tile_scaling), null };
}
}
}


+ 8
- 3
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -585,9 +585,14 @@ namespace Tensorflow
}

public static Tensor tile(Tensor input, Tensor multiples, string name = null)
{
throw new NotImplementedException("tile");
}
=> tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, multiples)
{
GetGradientAttrs = (op) => new
{
T = op.get_attr<TF_DataType>("T"),
Tmultiples = op.get_attr<TF_DataType>("Tmultiples")
}
});

public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
{


+ 16
- 3
test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs View File

@@ -178,6 +178,19 @@ namespace TensorFlowNET.UnitTest.Gradient
[TestMethod]
public void testReduceSumGradients()
{
/* python code
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

x = tf.placeholder(tf.float64, shape = (1, 1))
m = tf.broadcast_to(x, (2, 3))
g0 = tf.gradients(tf.reduce_sum(m), x)[0]
g1 = tf.gradients(tf.reduce_sum(m, axis = 0), x)[0]
g2 = tf.gradients(tf.reduce_sum(m, axis = 1), x)[0]
with tf.compat.v1.Session() as sess:
(r0, r1, r2) = sess.run((g0, g1, g2), {x: [[1.0]]})
*/

var x = tf.placeholder(tf.float64, shape: new Shape(1, 1));
var m = tf.broadcast_to(x, new Shape(2, 3));
var g0 = tf.gradients(tf.reduce_sum(m), x)[0];
@@ -186,10 +199,10 @@ namespace TensorFlowNET.UnitTest.Gradient

using (var session = tf.Session())
{
var (r0, r1, r2) = session.run((g0, g1, g2), new FeedItem(x, 1.0));
var (r0, r1, r2) = session.run((g0, g1, g2), new FeedItem(x, new[,] { { 1.0 } }));
self.assertFloat64Equal(6.0, r0[0], $"tf.reduce_sum(...)");
self.assertFloat64Equal(2.0, r1[0], $"tf.reduce_sum(..., axis = 0)");
self.assertFloat64Equal(3.0, r2[0], $"tf.reduce_sum(..., axis = 1)");
self.assertFloat64Equal(6.0, r1[0], $"tf.reduce_sum(..., axis = 0)");
self.assertFloat64Equal(6.0, r2[0], $"tf.reduce_sum(..., axis = 1)");
}
}



Loading…
Cancel
Save