diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs
index 9ab64848..e5e99862 100644
--- a/src/TensorFlowNET.Core/Gradients/math_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs
@@ -168,9 +168,42 @@ namespace Tensorflow.Gradients
             return new Tensor[] { grad_a, grad_b };
         }
 
+        [RegisterGradient("BatchMatMul")]
         public static Tensor[] _BatchMatMul(Operation op, Tensor[] grads)
         {
-            throw new NotImplementedException();
+            var grad = grads[0];
+            Tensor grad_a = null, grad_b = null;
+
+            var t_a = (bool)op.get_attr("adj_x");
+            var t_b = (bool)op.get_attr("adj_y");
+            var a = math_ops.conj(op.inputs[0]);
+            var b = math_ops.conj(op.inputs[1]);
+            if (!t_a && !t_b)
+            {
+                // C = A * B: dA = dC * B^T, dB = A^T * dC
+                grad_a = gen_math_ops.batch_mat_mul(grad, b, adj_y: true);
+                grad_b = gen_math_ops.batch_mat_mul(a, grad, adj_x: true);
+            }
+            else if (!t_a && t_b)
+            {
+                // C = A * B^T: dA = dC * B, dB = dC^T * A
+                grad_a = gen_math_ops.batch_mat_mul(grad, b);
+                grad_b = gen_math_ops.batch_mat_mul(grad, a, adj_x: true);
+            }
+            else if (t_a && !t_b)
+            {
+                // C = A^T * B: dA = B * dC^T, dB = A * dC
+                grad_a = gen_math_ops.batch_mat_mul(b, grad, adj_y: true);
+                grad_b = gen_math_ops.batch_mat_mul(a, grad);
+            }
+            else if (t_a && t_b)
+            {
+                // C = A^T * B^T: dA = B^T * dC^T, dB = dC^T * A^T
+                grad_a = gen_math_ops.batch_mat_mul(b, grad, adj_x: true, adj_y: true);
+                grad_b = gen_math_ops.batch_mat_mul(grad, a, adj_x: true, adj_y: true);
+            }
+
+            return new Tensor[] { grad_a, grad_b };
         }
 
         [RegisterGradient("Mean")]
diff --git a/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs b/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs
index 13182a68..67a84d48 100644
--- a/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs
+++ b/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs
@@ -1,7 +1,9 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;
 using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
 using Tensorflow;
 using static Tensorflow.Python;
 
@@ -30,6 +32,38 @@ namespace TensorFlowNET.UnitTest.gradients_test
             });
         }
 
+        [TestMethod]
+        public void testBatchMatMulGradient()
+        {
+            var a = tf.constant(np.array(Enumerable.Range(1, 18).Select(elem => (float)elem).ToArray()), shape: new[] { 2, 3, 3 });
+            var b = tf.divide(a, tf.constant(2.0f));
+            var c = tf.batch_matmul(a, b);
+            var g = tf.gradients(c, new[] { a, b }, stop_gradients: new[] { a, b });
+            // Expected values: d(c)/d(a) for both batches, then d(c)/d(b).
+            var checkG = new[]
+            {
+                3.0f, 7.5f, 12.0f,
+                3.0f, 7.5f, 12.0f,
+                3.0f, 7.5f, 12.0f,
+                16.5f, 21.0f, 25.5f,
+                16.5f, 21.0f, 25.5f,
+                16.5f, 21.0f, 25.5f,
+                12.0f, 12.0f, 12.0f,
+                15.0f, 15.0f, 15.0f,
+                18.0f, 18.0f, 18.0f,
+                39.0f, 39.0f, 39.0f,
+                42.0f, 42.0f, 42.0f,
+                45.0f, 45.0f, 45.0f
+            };
+            using (var sess = tf.Session())
+            {
+                var result = sess.run(g);
+                var resultList = result[0].GetData().ToList();
+                resultList.AddRange(result[1].GetData());
+                CollectionAssert.AreEqual(checkG, resultList.ToArray());
+            }
+        }
+
         [Ignore("TODO")]
         [TestMethod]
         public void testUnusedOutput()
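
For reference, the expected values in checkG can be derived by hand. tf.gradients uses an all-ones tensor as the default upstream gradient dC, and stop_gradients: new[] { a, b } cuts the chain from b = a / 2 back to a, so the two gradients are independent of each other. The first branch of _BatchMatMul then gives grad_a = dC * b^T, so every row of a batch is the vector of row sums of b, and grad_b = a^T * dC, so row i repeats the column-i sum of a. Below is a minimal standalone sketch that reproduces checkG from these identities (plain C#, no TensorFlow.NET dependency; the class name is illustrative):

using System;

// Sanity check for the expected gradients in testBatchMatMulGradient.
// For c = batch_matmul(a, b) with upstream gradient dC = ones:
//   grad_a = dC * b^T  -> grad_a[i, j] = sum over row j of b
//   grad_b = a^T * dC  -> grad_b[i, j] = sum over column i of a
class BatchMatMulGradCheck
{
    static void Main()
    {
        // Same inputs as the test: a = 1..18 reshaped to (2, 3, 3), b = a / 2.
        var a = new float[2, 3, 3];
        for (int i = 0; i < 18; i++)
            a[i / 9, i / 3 % 3, i % 3] = i + 1;

        // grad_a: every row of batch m is the row-sum vector of b's batch m.
        for (int m = 0; m < 2; m++)
            for (int i = 0; i < 3; i++)
            {
                for (int j = 0; j < 3; j++)
                {
                    float rowSum = 0f;
                    for (int k = 0; k < 3; k++)
                        rowSum += a[m, j, k] / 2f; // b = a / 2
                    Console.Write($"{rowSum} ");
                }
                Console.WriteLine();
            }

        // grad_b: row i of batch m repeats the column-i sum of a's batch m.
        for (int m = 0; m < 2; m++)
            for (int i = 0; i < 3; i++)
            {
                for (int j = 0; j < 3; j++)
                {
                    float colSum = 0f;
                    for (int k = 0; k < 3; k++)
                        colSum += a[m, k, i];
                    Console.Write($"{colSum} ");
                }
                Console.WriteLine();
            }
    }
}

The printed rows come out in the same order as checkG: [3, 7.5, 12] repeated three times for the first batch of grad_a, [16.5, 21, 25.5] for the second, then [12, 12, 12], [15, 15, 15], [18, 18, 18] and [39, 39, 39], [42, 42, 42], [45, 45, 45] for grad_b.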