
Adding `BatchMatMul` gradient (#304)

Unit testing the gradient too.
tags/v0.10
Antonio Haiping · 6 years ago
commit c616eea1e8
2 changed files with 65 additions and 1 deletion:
  1. src/TensorFlowNET.Core/Gradients/math_grad.cs (+30, -1)
  2. test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs (+35, -0)

src/TensorFlowNET.Core/Gradients/math_grad.cs (+30, -1)

@@ -168,9 +168,38 @@ namespace Tensorflow.Gradients
             return new Tensor[] { grad_a, grad_b };
         }
 
         [RegisterGradient("BatchMatMul")]
         public static Tensor[] _BatchMatMul(Operation op, Tensor[] grads)
         {
-            throw new NotImplementedException();
+            var grad = grads[0];
+            Tensor grad_a = null, grad_b = null;
+
+            var t_a = (bool)op.get_attr("adj_x");
+            var t_b = (bool)op.get_attr("adj_y");
+            var a = math_ops.conj(op.inputs[0]);
+            var b = math_ops.conj(op.inputs[1]);
+            if (!t_a && !t_b)
+            {
+                grad_a = gen_math_ops.batch_mat_mul(grad, b, adj_y: true);
+                grad_b = gen_math_ops.batch_mat_mul(a, grad, adj_x: true);
+            }
+            else if (!t_a && t_b)
+            {
+                grad_a = gen_math_ops.batch_mat_mul(grad, b);
+                grad_b = gen_math_ops.batch_mat_mul(grad, a, adj_x: true);
+            }
+            else if (t_a && !t_b)
+            {
+                grad_a = gen_math_ops.batch_mat_mul(b, grad, adj_y: true);
+                grad_b = gen_math_ops.batch_mat_mul(a, grad);
+            }
+            else if (t_a && t_b)
+            {
+                grad_a = gen_math_ops.batch_mat_mul(b, grad, adj_x: true, adj_y: true);
+                grad_b = gen_math_ops.batch_mat_mul(grad, a, adj_x: true, adj_y: true);
+            }
+
+            return new Tensor[] { grad_a, grad_b };
         }
 
         [RegisterGradient("Mean")]
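For reference, the four branches implement the standard matmul gradient identities, the same case analysis as TensorFlow's Python _BatchMatMul in math_grad.py. Writing G for the incoming gradient with respect to C, and A, B for the conjugated inputs:

\[
\begin{aligned}
C &= AB: & \partial A &= G B^{\top}, & \partial B &= A^{\top} G \\
C &= A B^{\top}: & \partial A &= G B, & \partial B &= G^{\top} A \\
C &= A^{\top} B: & \partial A &= B G^{\top}, & \partial B &= A G \\
C &= A^{\top} B^{\top}: & \partial A &= B^{\top} G^{\top}, & \partial B &= G^{\top} A^{\top}
\end{aligned}
\]

Each line corresponds to one if/else branch, applied per batch; the adj_x/adj_y flags on batch_mat_mul supply the transposes without materializing them.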


test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs (+35, -0)

@@ -1,7 +1,9 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;
 using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
 using Tensorflow;
 using static Tensorflow.Python;
@@ -30,6 +32,39 @@ namespace TensorFlowNET.UnitTest.gradients_test
             });
         }
 
+        [TestMethod]
+        public void testBatchMatMulGradient()
+        {
+            var a = tf.constant(np.array(Enumerable.Range(1, 18).Select(elem => (float)elem).ToArray()), shape: new[] { 2, 3, 3 });
+            var b = tf.divide(a, tf.constant(2.0f));
+            var c = tf.batch_matmul(a, b);
+            var g = tf.gradients(c, new[] { a, b }, stop_gradients: new[] { a, b });
+            var checkG = new[]
+            {
+                3.0f, 7.5f, 12.0f,    // d(c)/d(a), batch 0: every row repeats the row sums of b[0]
+                3.0f, 7.5f, 12.0f,
+                3.0f, 7.5f, 12.0f,
+                16.5f, 21.0f, 25.5f,  // d(c)/d(a), batch 1
+                16.5f, 21.0f, 25.5f,
+                16.5f, 21.0f, 25.5f,
+                12.0f, 12.0f, 12.0f,  // d(c)/d(b), batch 0: every column repeats the column sums of a[0]
+                15.0f, 15.0f, 15.0f,
+                18.0f, 18.0f, 18.0f,
+                39.0f, 39.0f, 39.0f,  // d(c)/d(b), batch 1
+                42.0f, 42.0f, 42.0f,
+                45.0f, 45.0f, 45.0f
+            };
+
+            using (var sess = tf.Session())
+            {
+                var result = sess.run(g);
+                var resultList = result[0].GetData<float>().ToList();
+                resultList.AddRange(result[1].GetData<float>());
+                CollectionAssert.AreEqual(checkG, resultList.ToArray());
+            }
+        }
+
         [Ignore("TODO")]
         [TestMethod]
         public void testUnusedOutput()
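A quick sanity check of checkG: tf.gradients seeds the backward pass with an all-ones tensor, so with no adjoints the first branch of _BatchMatMul yields, per batch,

\[
(\partial a)_{ij} = \sum_k b_{jk} \quad (\text{row sums of } b), \qquad
(\partial b)_{ij} = \sum_k a_{ki} \quad (\text{column sums of } a).
\]

For batch 0, b = a / 2 has row sums 3, 7.5, 12 and a has column sums 12, 15, 18; for batch 1 the same computation gives 16.5, 21, 25.5 and 39, 42, 45, which is exactly the flattened checkG array.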

