
add EinsumGrad

tags/v0.110.4-Transformer-Model
lingbai-kong committed 2 years ago
commit f026963a7d
1 changed file with 131 additions and 0 deletions:
  src/TensorFlowNET.Core/Gradients/math_grad.cs (+131, -0)


@@ -117,6 +117,137 @@ namespace Tensorflow.Gradients
            };
        }

        public static string ellipsis = "...";

        [RegisterGradient("Einsum")]
        public static Tensor[] _EinsumGrad(Operation op, Tensor[] grads)
        {
            // Gradient for Einsum.
            string equation = (string)op.get_attr("equation");
            string[] split_equation = equation.Split(new string[] { "->" }, StringSplitOptions.None);
            var input_subs = split_equation[0];
            var output_subs = split_equation[1];
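            // For example, "ab,bc->ac" splits into input_subs = "ab,bc" and output_subs = "ac".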

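            // Unary case: the VJP is just the transposed einsum when no labels are
            // reduced (e.g. "ab->ba" yields "ba->ab"); a reduced label such as 'a'
            // in "ab->b" drops a dimension, so _GetGradReduced restores it instead.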
            if (op.inputs.Length == 1)
            {
                var input_shape = array_ops.shape(op.inputs[0]);
                var reduced_label_set = new HashSet<char>(input_subs.Except(output_subs + ellipsis));
                if (reduced_label_set.Count == 0)
                    return new Tensor[] { math_ops.einsum(string.Format("{0}->{1}", output_subs, input_subs), new Tensors(grads)) };
                return new Tensor[] { _GetGradReduced(new Tensors(grads), output_subs, input_subs, input_shape, reduced_label_set) };
            }

            string[] split_input_subs = input_subs.Split(new string[] { "," }, StringSplitOptions.None);
            var x_subs = split_input_subs[0];
            var y_subs = split_input_subs[1];
            // Add ellipsis for broadcasted dimensions if any operand does not have it.
            // This is because the equation "...ij,jk->ik" may be valid if the 0th input's
            // batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid
            // because only the output subscripts contain ellipsis.
            if (output_subs.Contains(ellipsis))
            {
                if (!x_subs.Contains(ellipsis))
                    x_subs += ellipsis;
                if (!y_subs.Contains(ellipsis))
                    y_subs += ellipsis;
            }
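            // e.g. for "...ij,jk->...ik", x_subs stays "...ij" and y_subs becomes "jk...".
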
            // Obtain the gradients wrt the inputs x and y, without taking into account
            // the unbroadcasting.
            var x = op.inputs[0];
            var y = op.inputs[1];
            if (grads.GetDataType().is_complex())
            {
                x = math_ops.conj(x);
                y = math_ops.conj(y);
            }

            var x_shape = array_ops.shape(x);
            var y_shape = array_ops.shape(y);
            var grad_x = _GetGradWrt(grads, y, x_shape, x_subs, y_subs, output_subs);
            var grad_y = _GetGradWrt(grads, x, y_shape, y_subs, x_subs, output_subs);
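            // Each gradient is itself an einsum: for "ab,bc->ac" the gradient wrt x
            // is einsum("ac,bc->ab", grad, y) and the gradient wrt y is
            // einsum("ac,ab->bc", grad, x).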

            if (!output_subs.Contains(ellipsis))
                return new Tensor[] { grad_x, grad_y };
            var bx = _GetBcastSubshape(x_subs);
            int bx_start = bx[0], bx_end = bx[1];
            var by = _GetBcastSubshape(y_subs);
            int by_start = by[0], by_end = by[1];

            var x_shape_static = x.shape;
            var y_shape_static = y.shape;
            if (x_shape_static.IsFullyDefined &&
                y_shape_static.IsFullyDefined &&
                x_shape_static[string.Format("{0}:{1}", bx_start, bx_end)] == y_shape_static[string.Format("{0}:{1}", by_start, by_end)])
                return new Tensor[] { grad_x, grad_y };

            var r = gen_array_ops.broadcast_gradient_args(x_shape[string.Format("{0}:{1}", bx_start, bx_end)],
                                                          y_shape[string.Format("{0}:{1}", by_start, by_end)]);
            var rx = r[0];
            var ry = r[1];
            grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, bx_start + rx), x_shape);
            grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, by_start + ry), y_shape);
            return new Tensor[] { grad_x, grad_y };
        }
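
        // A minimal usage sketch of the registered gradient (assuming the eager
        // GradientTape API and tf.linalg.einsum as exposed by this repository;
        // names and shapes here are illustrative, not part of this commit):
        //     var x = tf.ones(new Shape(2, 3));
        //     var y = tf.ones(new Shape(3, 4));
        //     using var tape = tf.GradientTape();
        //     tape.watch(x);
        //     var z = tf.linalg.einsum("ij,jk->ik", new Tensors(x, y));
        //     var dx = tape.gradient(z, x);   // same shape as x: (2, 3)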

        protected static Tensor _GetGradWrt(Tensor[] output_grads, Tensor other_operand, Tensor input_shape,
            string input_subs, string other_subs, string output_subs)
        {
            var reduced_label_set = new HashSet<char>(input_subs.Except(output_subs + other_subs + "."));
            var left_subs = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s)));
            var grad_reduced = math_ops.einsum(string.Format("{0},{1}->{2}", output_subs, other_subs, left_subs),
                new Tensors((Tensors)output_grads, other_operand));
            if (reduced_label_set.Count == 0)
                return grad_reduced;
            return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape, reduced_label_set);
        }
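
        // e.g. for the forward "ab,b->b", the gradient wrt x is computed as
        // einsum("b,b->b", grad, y), after which _GetGradReduced rebuilds the
        // reduced 'a' dimension by broadcasting.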

        protected static Tensor _GetGradReduced(Tensor output_grad, string output_subs, string input_subs,
            Tensor input_shape, HashSet<char> reduced_label_set)
        {
            string reduced_subs;
            Tensor reduced_dims;
            List<int> reduced_axes;
            _GetReducedSubscripts(reduced_label_set, input_shape, input_subs,
                out reduced_subs, out reduced_dims, out reduced_axes);
            bool has_repeated_labels = (
                new HashSet<char>(input_subs).Count + new HashSet<char>(output_subs).Count <
                input_subs.Length + output_subs.Length);
            var input_subs_without_reduced_labels = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s)));

            if (!has_repeated_labels && input_subs_without_reduced_labels == output_subs)
            {
                var reduced_shape = math_ops.reduced_shape(input_shape, ops.convert_to_tensor(reduced_axes));
                return gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), input_shape);
            }
            else
            {
                var grad_shape_with_reduced_labels = array_ops.concat(
                    new Tensor[] { reduced_dims, array_ops.shape(new Tensors(output_grad)) }, axis: 0);
                var reduced_shape = array_ops.concat(
                    new Tensor[] { array_ops.ones(reduced_label_set.Count, dtype: dtypes.int32), array_ops.shape(new Tensors(output_grad)) }, axis: 0);
                var broadcasted_grad = gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), grad_shape_with_reduced_labels);
                return math_ops.einsum(string.Format("{0}->{1}", reduced_subs + output_subs, input_subs), new Tensors(broadcasted_grad));
            }
        }
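
        // e.g. for the unary forward "ab->b" with x of shape (2, 3): the incoming
        // grad has shape (3,) and is reshaped to (1, 3), then broadcast to (2, 3).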

        protected static void _GetReducedSubscripts(HashSet<char> reduced_label_set, Tensor input_shape,
            string subscripts, out string reduced_subs, out Tensor reduced_dims, out List<int> reduced_axes)
        {
            reduced_subs = string.Join("", reduced_label_set.Select(c => c.ToString()));
            reduced_axes = reduced_subs.Select(s => _GetAxisFromLabel(subscripts, s)).ToList();
            reduced_dims = array_ops.stack(reduced_axes.Select(ax => input_shape[ax]).ToList());
        }
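
        // e.g. with subscripts "abc" and reduced labels {'a','c'}: reduced_subs
        // concatenates the labels ("ac"), reduced_axes is [0, 2], and reduced_dims
        // stacks the corresponding entries of input_shape.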

        protected static int _GetAxisFromLabel(string subscripts, char label)
        {
            // Returns the axis of `label` in `subscripts`; labels that appear after
            // an ellipsis are counted from the back (negative axis), because the
            // ellipsis spans an unknown number of dimensions.
            var splits = subscripts.Split(new string[] { ellipsis }, StringSplitOptions.None);
            var index = splits[0].IndexOf(label);
            if (index != -1) return index;
            if (splits.Length < 2) throw new OutOfRangeError($"Invalid subscripts: {subscripts}");
            index = splits[1].IndexOf(label);
            if (index != -1) return index - splits[1].Length;
            throw new ValueError($"Invalid subscripts: {subscripts}");
        }
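
        // e.g. _GetAxisFromLabel("ab...cd", 'a') == 0, while
        // _GetAxisFromLabel("ab...cd", 'd') == -1 (last axis, counted from the back).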

        protected static int[] _GetBcastSubshape(string subscripts)
        {
            // Returns the [start, end) bounds of the broadcast ("...") span as slice
            // indices; a negative end counts from the back of the shape, mirroring
            // the Python reference implementation.
            int start = subscripts.IndexOf(ellipsis);
            if (start == -1)
                return new int[] { 0, 0 };
            int remaining = subscripts.Length - (start + ellipsis.Length);
            if (remaining > 0)
                return new int[] { start, -remaining };
            // A trailing ellipsis would need an open-ended slice, which this port
            // does not express yet.
            throw new NotSupportedException($"Trailing ellipsis in \"{subscripts}\" is not supported.");
        }
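
        // e.g. _GetBcastSubshape("...ij") == { 0, -2 }: the broadcast span covers
        // everything up to the last two axes.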

        /// <summary>
        /// Returns grad * exp(x).
        /// </summary>

