@@ -168,9 +168,38 @@ namespace Tensorflow.Gradients | |||||
return new Tensor[] { grad_a, grad_b }; | return new Tensor[] { grad_a, grad_b }; | ||||
} | } | ||||
[RegisterGradient("BatchMatMul")]
public static Tensor[] _BatchMatMul(Operation op, Tensor[] grads)
{
    // Gradient for BatchMatMul: returns { dL/dA, dL/dB } for C = batch_matmul(A, B).
    // Mirrors TensorFlow's math_grad.py _BatchMatMul; inputs are conjugated so the
    // formulas are correct for complex tensors as well.
    var grad = grads[0];
    Tensor grad_a = null, grad_b = null;

    // adj_x / adj_y record whether each operand was adjointed in the forward op;
    // the backward formulas differ per combination.
    var t_a = (bool)op.get_attr("adj_x");
    var t_b = (bool)op.get_attr("adj_y");
    var a = math_ops.conj(op.inputs[0]);
    var b = math_ops.conj(op.inputs[1]);

    if (!t_a && !t_b)
    {
        grad_a = gen_math_ops.batch_mat_mul(grad, b, adj_y: true);
        grad_b = gen_math_ops.batch_mat_mul(a, grad, adj_x: true);
    }
    else if (!t_a && t_b)
    {
        grad_a = gen_math_ops.batch_mat_mul(grad, b);
        grad_b = gen_math_ops.batch_mat_mul(grad, a, adj_x: true);
    }
    else if (t_a && !t_b)
    {
        // Fixed: previous code duplicated the (!t_a && t_b) branch here, which
        // yields wrong gradients (and mismatched shapes for non-square inputs).
        // Correct formulas: grad_a = B * grad^T, grad_b = A * grad.
        grad_a = gen_math_ops.batch_mat_mul(b, grad, adj_y: true);
        grad_b = gen_math_ops.batch_mat_mul(a, grad);
    }
    else // t_a && t_b
    {
        grad_a = gen_math_ops.batch_mat_mul(b, grad, adj_x: true, adj_y: true);
        grad_b = gen_math_ops.batch_mat_mul(grad, a, adj_x: true, adj_y: true);
    }

    return new Tensor[] { grad_a, grad_b };
}
[RegisterGradient("Mean")] | [RegisterGradient("Mean")] | ||||
@@ -1,7 +1,9 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | |||||
using System.Text; | using System.Text; | ||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using NumSharp; | |||||
using Tensorflow; | using Tensorflow; | ||||
using static Tensorflow.Python; | using static Tensorflow.Python; | ||||
@@ -30,6 +32,39 @@ namespace TensorFlowNET.UnitTest.gradients_test | |||||
}); | }); | ||||
} | } | ||||
[TestMethod]
public void testBatchMatMulGradient()
{
    // a: two 3x3 batches filled with 1..18; b = a / 2; c = batch_matmul(a, b).
    var a = tf.constant(np.array(Enumerable.Range(1, 18).Select(elem => (float)elem).ToArray()), shape: new[] { 2, 3, 3 });
    var b = tf.divide(a, tf.constant(2.0f));
    var c = tf.batch_matmul(a, b);
    var g = tf.gradients(c, new[] { a, b }, stop_gradients: new[] { a, b });

    // Expected gradients, flattened: first da (2x3x3), then db (2x3x3),
    // computed analytically for d(sum c)/da = sum over columns of b^T, etc.
    var expected = new[]
    {
        3.0f, 7.5f, 12.0f,
        3.0f, 7.5f, 12.0f,
        3.0f, 7.5f, 12.0f,
        16.5f, 21.0f, 25.5f,
        16.5f, 21.0f, 25.5f,
        16.5f, 21.0f, 25.5f,
        12.0f, 12.0f, 12.0f,
        15.0f, 15.0f, 15.0f,
        18.0f, 18.0f, 18.0f,
        39.0f, 39.0f, 39.0f,
        42.0f, 42.0f, 42.0f,
        45.0f, 45.0f, 45.0f
    };

    using (var sess = tf.Session())
    {
        var result = sess.run(g);
        var actual = result[0].GetData<float>().ToList();
        actual.AddRange(result[1].GetData<float>());
        // CollectionAssert.AreEqual takes (expected, actual) — keeping that
        // order gives correct failure messages.
        CollectionAssert.AreEqual(expected, actual.ToArray());
    }
}
[Ignore("TODO")] | [Ignore("TODO")] | ||||
[TestMethod] | [TestMethod] | ||||
public void testUnusedOutput() | public void testUnusedOutput() | ||||