|
|
@@ -168,9 +168,38 @@ namespace Tensorflow.Gradients |
|
|
|
return new Tensor[] { grad_a, grad_b }; |
|
|
|
} |
|
|
|
|
|
|
|
[RegisterGradient("BatchMatMul")] |
|
|
|
public static Tensor[] _BatchMatMul(Operation op, Tensor[] grads) |
|
|
|
{ |
|
|
|
throw new NotImplementedException(); |
|
|
|
var grad = grads[0]; |
|
|
|
Tensor grad_a = null, grad_b = null; |
|
|
|
|
|
|
|
var t_a = (bool)op.get_attr("adj_x"); |
|
|
|
var t_b = (bool)op.get_attr("adj_y"); |
|
|
|
var a = math_ops.conj(op.inputs[0]); |
|
|
|
var b = math_ops.conj(op.inputs[1]); |
|
|
|
if (!t_a && !t_b) |
|
|
|
{ |
|
|
|
grad_a = gen_math_ops.batch_mat_mul(grad, b, adj_y: true); |
|
|
|
grad_b = gen_math_ops.batch_mat_mul(a, grad, adj_x: true); |
|
|
|
} |
|
|
|
else if (!t_a && t_b) |
|
|
|
{ |
|
|
|
grad_a = gen_math_ops.batch_mat_mul(grad, b); |
|
|
|
grad_b = gen_math_ops.batch_mat_mul(grad, a, adj_x: true); |
|
|
|
} |
|
|
|
else if (t_a && !t_b) |
|
|
|
{ |
|
|
|
grad_a = gen_math_ops.batch_mat_mul(grad, b); |
|
|
|
grad_b = gen_math_ops.batch_mat_mul(grad, a, adj_x: true); |
|
|
|
} |
|
|
|
else if (t_a && t_b) |
|
|
|
{ |
|
|
|
grad_a = gen_math_ops.batch_mat_mul(b, grad, adj_x: true, adj_y: true); |
|
|
|
grad_b = gen_math_ops.batch_mat_mul(grad, a, adj_x: true, adj_y: true); |
|
|
|
} |
|
|
|
|
|
|
|
return new Tensor[] { grad_a, grad_b }; |
|
|
|
} |
|
|
|
|
|
|
|
[RegisterGradient("Mean")] |
|
|
|