|
@@ -117,6 +117,137 @@ namespace Tensorflow.Gradients |
|
|
}; |
|
|
}; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
public static string ellipsis = "..."; |
|
|
|
|
|
[RegisterGradient("Einsum")] |
|
|
|
|
|
public static Tensor[] _EinsumGrad(Operation op, Tensor[] grads) |
|
|
|
|
|
{ |
|
|
|
|
|
// Gradient for Einsum. |
|
|
|
|
|
string equation = (string)op.get_attr("equation"); |
|
|
|
|
|
string[] split_equation = equation.Split(new string[] { "->" }, StringSplitOptions.None); |
|
|
|
|
|
var input_subs = split_equation[0]; |
|
|
|
|
|
var output_subs = split_equation[1]; |
|
|
|
|
|
|
|
|
|
|
|
if (op.inputs.Length == 1) |
|
|
|
|
|
{ |
|
|
|
|
|
var input_shape = array_ops.shape(op.inputs[0]); |
|
|
|
|
|
var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + ellipsis))); |
|
|
|
|
|
if (reduced_label_set.Count == 0) |
|
|
|
|
|
return new Tensor[] { math_ops.einsum(string.Format("{0}->{1}", output_subs, input_subs), new Tensors(grads)) }; |
|
|
|
|
|
return new Tensor[] { _GetGradReduced(new Tensors(grads), output_subs, input_subs, input_shape, reduced_label_set) }; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
string[] split_input_subs = input_subs.Split(new string[] { "," }, StringSplitOptions.None); |
|
|
|
|
|
var x_subs = split_input_subs[0]; |
|
|
|
|
|
var y_subs = split_input_subs[1]; |
|
|
|
|
|
// Add ellipsis for broadcasted dimensions if any operand does not have it. |
|
|
|
|
|
// This is because the equation "...ij,jk->ik" may be valid if the 0th input's |
|
|
|
|
|
// batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid |
|
|
|
|
|
// because only the output subscripts contain ellipsis. |
|
|
|
|
|
if (output_subs.Contains(ellipsis)) |
|
|
|
|
|
{ |
|
|
|
|
|
if (!x_subs.Contains(ellipsis)) |
|
|
|
|
|
x_subs += ellipsis; |
|
|
|
|
|
if (!y_subs.Contains(ellipsis)) |
|
|
|
|
|
y_subs += ellipsis; |
|
|
|
|
|
} |
|
|
|
|
|
// Obtain the gradients wrt the inputs x and y, without taking into account |
|
|
|
|
|
// the unbroadcasting. |
|
|
|
|
|
var x = op.inputs[0]; |
|
|
|
|
|
var y = op.inputs[1]; |
|
|
|
|
|
if (grads.GetDataType().is_complex()) |
|
|
|
|
|
{ |
|
|
|
|
|
x = math_ops.conj(x); |
|
|
|
|
|
y = math_ops.conj(y); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
var x_shape = array_ops.shape(x); |
|
|
|
|
|
var y_shape = array_ops.shape(y); |
|
|
|
|
|
var grad_x = _GetGradWrt(grads, y, x_shape, x_subs, y_subs, output_subs); |
|
|
|
|
|
var grad_y = _GetGradWrt(grads, x, y_shape, y_subs, x_subs, output_subs); |
|
|
|
|
|
|
|
|
|
|
|
if (!output_subs.Contains(ellipsis)) |
|
|
|
|
|
return new Tensor[] { grad_x, grad_y }; |
|
|
|
|
|
var bx = _GetBcastSubshape(x_subs); |
|
|
|
|
|
int bx_start = bx[0], bx_end = bx[1]; |
|
|
|
|
|
var by = _GetBcastSubshape(y_subs); |
|
|
|
|
|
int by_start = by[0], by_end = by[1]; |
|
|
|
|
|
|
|
|
|
|
|
var x_shape_static = x.shape; |
|
|
|
|
|
var y_shape_static = y.shape; |
|
|
|
|
|
if(x_shape_static.IsFullyDefined && |
|
|
|
|
|
y_shape_static.IsFullyDefined && |
|
|
|
|
|
x_shape_static[string.Format("{0}:{1}",bx_start,bx_end)] == y_shape_static[string.Format("{0}:{1}", by_start, by_end)]) |
|
|
|
|
|
return new Tensor[] { grad_x, grad_y }; |
|
|
|
|
|
|
|
|
|
|
|
var r = gen_array_ops.broadcast_gradient_args(x_shape[string.Format("{0}:{1}", bx_start, bx_end)], |
|
|
|
|
|
y_shape[string.Format("{0}:{1}", by_start, by_end)]); |
|
|
|
|
|
var rx = r[0]; |
|
|
|
|
|
var ry = r[1]; |
|
|
|
|
|
grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, bx_start + rx), x_shape); |
|
|
|
|
|
grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, by_start + ry), y_shape); |
|
|
|
|
|
return new Tensor[] { grad_x, grad_y }; |
|
|
|
|
|
} |
|
|
|
|
|
protected static Tensor _GetGradWrt(Tensor[] output_grads, Tensor other_operand, Tensor input_shape, |
|
|
|
|
|
string input_subs, string other_subs, string output_subs) |
|
|
|
|
|
{ |
|
|
|
|
|
var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + other_subs + "."))); |
|
|
|
|
|
var left_subs = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s))); |
|
|
|
|
|
var grad_reduced = math_ops.einsum(string.Format("{0},{1}->{2}", output_subs, other_subs, left_subs), new Tensors((Tensors)output_grads, other_operand)); |
|
|
|
|
|
if (reduced_label_set.Count == 0) |
|
|
|
|
|
return grad_reduced; |
|
|
|
|
|
return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape, reduced_label_set); |
|
|
|
|
|
} |
|
|
|
|
|
protected static Tensor _GetGradReduced(Tensor output_grad, string output_subs, string input_subs, Tensor input_shape, HashSet<char> reduced_label_set) |
|
|
|
|
|
{ |
|
|
|
|
|
string reduced_subs; |
|
|
|
|
|
Tensor reduced_dims; |
|
|
|
|
|
List<int> reduced_axes; |
|
|
|
|
|
_GetReducedSubscripts(reduced_label_set, input_shape, input_subs, out reduced_subs, out reduced_dims, out reduced_axes); |
|
|
|
|
|
bool has_repeated_labels = ( |
|
|
|
|
|
new HashSet<char>(input_subs).Count + new HashSet<char>(output_subs).Count < |
|
|
|
|
|
input_subs.Length + output_subs.Length); |
|
|
|
|
|
var input_subs_without_reduced_labels = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s))); |
|
|
|
|
|
|
|
|
|
|
|
if (!has_repeated_labels && input_subs_without_reduced_labels == output_subs) |
|
|
|
|
|
{ |
|
|
|
|
|
var reduced_shape = math_ops.reduced_shape(input_shape, ops.convert_to_tensor(reduced_axes)); |
|
|
|
|
|
return gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), input_shape); |
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
var grad_shape_with_reduced_labels = array_ops.concat(new Tensor[] { reduced_dims, array_ops.shape(new Tensors(output_grad)) }, axis: 0); |
|
|
|
|
|
var reduced_shape = array_ops.concat(new Tensor[] { array_ops.ones(reduced_label_set.Count, dtype: dtypes.int32), array_ops.shape(new Tensors(output_grad)) }, axis: 0); |
|
|
|
|
|
var broadcasted_grad = gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), grad_shape_with_reduced_labels); |
|
|
|
|
|
return math_ops.einsum(string.Format("{0}->{1}", reduced_subs + output_subs, input_subs), new Tensors(broadcasted_grad)); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
protected static void _GetReducedSubscripts(HashSet<char> reduced_label_set, Tensor input_shape, string subscripts, out string reduced_subs, out Tensor reduced_dims, out List<int> reduced_axes) |
|
|
|
|
|
{ |
|
|
|
|
|
reduced_subs = string.Join("", reduced_label_set.Select(c => c.ToString())); |
|
|
|
|
|
reduced_axes = reduced_subs.Select(s => _GetAxisFromLabel(subscripts, s)).ToList(); |
|
|
|
|
|
reduced_dims = array_ops.stack(reduced_axes.Select(ax => input_shape[ax]).ToList()); |
|
|
|
|
|
} |
|
|
|
|
|
protected static int _GetAxisFromLabel(string subscripts, char label) |
|
|
|
|
|
{ |
|
|
|
|
|
var splits = subscripts.Split(new string[] { ellipsis }, StringSplitOptions.None); |
|
|
|
|
|
var index = splits[0].IndexOf(label); |
|
|
|
|
|
if (index != -1) return index; |
|
|
|
|
|
if (splits.Length < 2) throw new OutOfRangeError(); |
|
|
|
|
|
index = splits[1].IndexOf(label); |
|
|
|
|
|
if (index != -1) return index; |
|
|
|
|
|
throw new ValueError(); |
|
|
|
|
|
} |
|
|
|
|
|
protected static int[] _GetBcastSubshape(string subscripts) |
|
|
|
|
|
{ |
|
|
|
|
|
int start = subscripts.IndexOf(ellipsis); |
|
|
|
|
|
if (start == -1) return new int[] { 0, 0 }; |
|
|
|
|
|
int remaining = subscripts.Length - (start + ellipsis.Length); |
|
|
|
|
|
int end; |
|
|
|
|
|
if (remaining > 0) end = remaining; |
|
|
|
|
|
else throw new Exception(); |
|
|
|
|
|
return new int[] { start, end }; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
/// <summary> |
|
|
/// <summary> |
|
|
/// Returns grad * exp(x). |
|
|
/// Returns grad * exp(x). |
|
|
/// </summary> |
|
|
/// </summary> |
|
|