@@ -168,9 +168,38 @@ namespace Tensorflow.Gradients | |||
return new Tensor[] { grad_a, grad_b }; | |||
} | |||
[RegisterGradient("BatchMatMul")] | |||
public static Tensor[] _BatchMatMul(Operation op, Tensor[] grads) | |||
{ | |||
throw new NotImplementedException(); | |||
var grad = grads[0]; | |||
Tensor grad_a = null, grad_b = null; | |||
var t_a = (bool)op.get_attr("adj_x"); | |||
var t_b = (bool)op.get_attr("adj_y"); | |||
var a = math_ops.conj(op.inputs[0]); | |||
var b = math_ops.conj(op.inputs[1]); | |||
if (!t_a && !t_b) | |||
{ | |||
grad_a = gen_math_ops.batch_mat_mul(grad, b, adj_y: true); | |||
grad_b = gen_math_ops.batch_mat_mul(a, grad, adj_x: true); | |||
} | |||
else if (!t_a && t_b) | |||
{ | |||
grad_a = gen_math_ops.batch_mat_mul(grad, b); | |||
grad_b = gen_math_ops.batch_mat_mul(grad, a, adj_x: true); | |||
} | |||
else if (t_a && !t_b) | |||
{ | |||
grad_a = gen_math_ops.batch_mat_mul(grad, b); | |||
grad_b = gen_math_ops.batch_mat_mul(grad, a, adj_x: true); | |||
} | |||
else if (t_a && t_b) | |||
{ | |||
grad_a = gen_math_ops.batch_mat_mul(b, grad, adj_x: true, adj_y: true); | |||
grad_b = gen_math_ops.batch_mat_mul(grad, a, adj_x: true, adj_y: true); | |||
} | |||
return new Tensor[] { grad_a, grad_b }; | |||
} | |||
[RegisterGradient("Mean")] | |||
@@ -1,7 +1,9 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using NumSharp; | |||
using Tensorflow; | |||
using static Tensorflow.Python; | |||
@@ -30,6 +32,39 @@ namespace TensorFlowNET.UnitTest.gradients_test | |||
}); | |||
} | |||
[TestMethod] | |||
public void testBatchMatMulGradient() | |||
{ | |||
var a = tf.constant(np.array(Enumerable.Range(1, 18).Select(elem => (float)elem).ToArray()), shape:new []{2, 3, 3}); | |||
var b = tf.divide(a, tf.constant(2.0f)); | |||
var c = tf.batch_matmul(a, b); | |||
var g = tf.gradients(c, new[] {a, b}, stop_gradients: new[] {a, b}); | |||
var checkG = new[] | |||
{ | |||
3.0f, 7.5f, 12.0f, | |||
3.0f, 7.5f, 12.0f, | |||
3.0f, 7.5f, 12.0f, | |||
16.5f, 21.0f, 25.5f, | |||
16.5f, 21.0f, 25.5f, | |||
16.5f, 21.0f, 25.5f, | |||
12.0f, 12.0f, 12.0f, | |||
15.0f, 15.0f, 15.0f, | |||
18.0f, 18.0f, 18.0f, | |||
39.0f, 39.0f, 39.0f, | |||
42.0f, 42.0f, 42.0f, | |||
45.0f, 45.0f, 45.0f | |||
}; | |||
using (var sess = tf.Session()) | |||
{ | |||
var result = sess.run(g); | |||
var resultList = result[0].GetData<float>().ToList(); | |||
resultList.AddRange(result[1].GetData<float>()); | |||
Console.WriteLine(result.ToString()); | |||
CollectionAssert.AreEqual(resultList.ToArray(), checkG); | |||
} | |||
} | |||
[Ignore("TODO")] | |||
[TestMethod] | |||
public void testUnusedOutput() | |||