Browse Source

fix: add the implementation of the tile's grad

tags/v0.150.0-BERT-Model
Wanglongzhi2001 1 year ago
parent
commit
a73694ab2d
3 changed files with 39 additions and 1 deletions
  1. +24
    -0
      src/TensorFlowNET.Core/Gradients/array_grad.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Operations/array_ops.cs
  3. +14
    -0
      test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs

+ 24
- 0
src/TensorFlowNET.Core/Gradients/array_grad.cs View File

@@ -381,5 +381,29 @@ namespace Tensorflow.Gradients
var axis = op.inputs[1]; var axis = op.inputs[1];
return new Tensor[] { array_ops.reverse(grad, axis), null }; return new Tensor[] { array_ops.reverse(grad, axis), null };
} }

[RegisterGradient("Tile")]
public static Tensor[] _TileGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
var input_shape = array_ops.shape(op.inputs[0], out_type: op.inputs[1].dtype);
var split_shape = array_ops.reshape(array_ops.transpose(array_ops.stack(new Tensor[] { op.inputs[1], input_shape })), new Shape(-1));
var axes = math_ops.range(0, array_ops.size(split_shape), 2);

//# Sum reduces grad along the first dimension for IndexedSlices
//if isinstance(grad, indexed_slices_lib.IndexedSlices):
//input_shape_0 = math_ops.cast(input_shape[0], grad.indices.dtype)
//grad = math_ops.unsorted_segment_sum(
// grad.values, math_ops.mod(grad.indices, input_shape_0), input_shape_0)
//split_shape = array_ops.concat([[1], split_shape[1:]], axis = 0)

var input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes);
if (!tf.Context.executing_eagerly())
{
input_grad.set_shape(op.inputs[0].GetShape());
}
return new Tensor[] { input_grad, null };

}
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -990,7 +990,7 @@ namespace Tensorflow
return @params.sparse_read(indices, name); return @params.sparse_read(indices, name);
} }


public static Tensor transpose<T1>(T1 a, Axis perm, string name = "transpose", bool conjugate = false)
public static Tensor transpose<T1>(T1 a, Axis perm = null, string name = "transpose", bool conjugate = false)
{ {
return tf_with(ops.name_scope(name, "transpose", new { a }), scope => return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
{ {


+ 14
- 0
test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs View File

@@ -173,5 +173,19 @@ namespace TensorFlowNET.UnitTest.Gradient
var result = grad(x, 4); var result = grad(x, 4);
Assert.AreEqual((float)result, 4.0f); Assert.AreEqual((float)result, 4.0f);
} }

[TestMethod]
public void Tile()
{
var a = tf.constant(new int[] { 1 }, TF_DataType.TF_FLOAT);
var b = tf.constant(new int[] { 2 });
using (var tape = tf.GradientTape())
{
tape.watch(a);
var y = tf.tile(a, b);
var grad = tape.gradient(y, a);
Assert.AreEqual((float)grad.numpy(), 2.0f);
}
}
} }
} }

Loading…
Cancel
Save