Browse Source

fix: add the implementation of GatherND's grad

tags/v0.150.0-BERT-Model
Wanglongzhi2001 1 year ago
parent
commit
d0ec6591a0
4 changed files with 46 additions and 2 deletions
  1. +10
    -0
      src/TensorFlowNET.Core/APIs/tf.array.cs
  2. +19
    -0
      src/TensorFlowNET.Core/Gradients/array_grad.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Operations/array_ops.cs
  4. +16
    -1
      test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs

+ 10
- 0
src/TensorFlowNET.Core/APIs/tf.array.cs View File

@@ -140,6 +140,16 @@ namespace Tensorflow
public Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0)
=> array_ops.gather(@params, indices, name: name, axis: ops.convert_to_tensor(axis));

/// <summary>
/// Gather slices from `params` into a Tensor with shape specified by `indices`.
/// </summary>
/// <param name="params"></param>
/// <param name="indices"></param>
/// <param name="name"></param>
/// <returns></returns>
public Tensor gather_nd(Tensor @params, Tensor indices, string name = null)
=> gen_array_ops.gather_nd(@params, indices, name: name);

/// <summary>
/// Return the elements, either from `x` or `y`, depending on the `condition`.
/// </summary>


+ 19
- 0
src/TensorFlowNET.Core/Gradients/array_grad.cs View File

@@ -403,7 +403,26 @@ namespace Tensorflow.Gradients
input_grad.set_shape(op.inputs[0].GetShape());
}
return new Tensor[] { input_grad, null };
}

[RegisterGradient("GatherNd")]
public static Tensor[] _GatherNdGrad(Operation op, Tensor[] grads)
{
var @ref = op.inputs[0];
var indices = op.inputs[1];
var grad = grads[0];
var ref_shape = array_ops.shape(@ref, out_type: indices.dtype);
Tensor ref_grad = null;
if (indices.shape.ndim == 2 && indices.shape.dims[indices.shape.Length - 1] == 1)
{
ref_grad = (Tensor)new IndexedSlices(grad, array_ops.squeeze(indices, axis: -1), ref_shape);
}
else
{
ref_grad = gen_array_ops.scatter_nd(indices, grad, ref_shape);
}
return new Tensor[] { ref_grad, null };
}

}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -829,7 +829,7 @@ namespace Tensorflow
/// <returns>A `Tensor`. Has the same type as `input`.
/// Contains the same data as `input`, but has one or more dimensions of
/// size 1 removed.</returns>
public static Tensor squeeze(Tensor input, int[] axis = null, string name = null)
public static Tensor squeeze(Tensor input, Axis axis = null, string name = null)
=> gen_array_ops.squeeze(input, axis, name);

public static Tensor identity(Tensor input, string name = null)


+ 16
- 1
test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs View File

@@ -62,7 +62,7 @@ namespace TensorFlowNET.UnitTest.Gradient
// Calcute the gradient of (x1-x2)^2
// by Automatic Differentiation in Eager mode
// Expected is 2*(abs(x1-x2))
Tensor x1 = new NDArray( new float[] { 1, 3, 5, 21, 19, 17 });
Tensor x1 = new NDArray(new float[] { 1, 3, 5, 21, 19, 17 });
Tensor x2 = new NDArray(new float[] { 29, 27, 23, 7, 11, 13 });
float[] expected = new float[]
{
@@ -187,5 +187,20 @@ namespace TensorFlowNET.UnitTest.Gradient
Assert.AreEqual((float)grad.numpy(), 2.0f);
}
}

[TestMethod]
public void GatherNdTest()
{
var x = tf.constant(new float[,] { { 1.0f, 2.0f, 3.0f }, { 1.0f, 2.0f, 3.0f }, { 1.0f, 2.0f, 3.0f } }, dtype: TF_DataType.TF_FLOAT);
var indices = tf.constant(new int[,] { { 0, 1 }, { 1, 1 }, { 2, 1 } }, dtype: TF_DataType.TF_INT32);
using (var tape = tf.GradientTape())
{
tape.watch(x);
var res = tf.gather_nd(x, indices);
var grad = tape.gradient(res, x);
var expected = np.array(new float[,] { { 0f, 1f, 0f }, { 0f, 1f, 0f }, { 0f, 1f, 0f } });
Assert.IsTrue(Enumerable.SequenceEqual(grad.ToArray<float>(), expected.ToArray<float>()));
}
}
}
}

Loading…
Cancel
Save