@@ -6,9 +6,9 @@ namespace Tensorflow.Gradients
 {
     public class array_grad
     {
-        public static (Tensor, Tensor) _ReshapeGrad(Operation op, Tensor grad)
+        public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads)
         {
-            return (array_ops.reshape(grad, array_ops.shape(op.inputs[0])), null);
+            return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null };
         }
     }
 }
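The hunk above shows the convention this PR moves to: a gradient function takes the op plus one upstream gradient per op output and returns one gradient per op input, with null for inputs that receive no gradient (here, Reshape's shape argument). A minimal caller-side sketch under that convention; `upstreamGrad` is an illustrative name, not something defined in this diff:

    // One upstream gradient per output of the Reshape op.
    Tensor[] outGrads = new Tensor[] { upstreamGrad };
    // One result per input: [0] flows back to the reshaped tensor,
    // [1] is null because the shape input is not differentiable.
    Tensor[] inGrads = array_grad._ReshapeGrad(op, outGrads);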
@@ -136,7 +136,7 @@ namespace Tensorflow
                 string name1 = scope1;
                 if (grad_fn != null)
                 {
-                    in_grads = _MaybeCompile(grad_scope, op, out_grads[0], null, grad_fn);
+                    in_grads = _MaybeCompile(grad_scope, op, out_grads, null, grad_fn);
                     _VerifyGeneratedGradients(in_grads, op);
                 }
@@ -226,7 +226,7 @@ namespace Tensorflow
                 $"inputs {op.inputs._inputs.Count()}");
         }

-        private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor out_grads, Action func, Func<Operation, Tensor, Tensor[]> grad_fn)
+        private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn)
         {
             scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope;
             return grad_fn(op, out_grads);
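With _MaybeCompile and the grad_fn delegate now taking the whole out_grads array instead of out_grads[0], multi-output ops (such as SoftmaxCrossEntropyWithLogits later in this PR) can see every upstream gradient. A sketch of the updated delegate contract; the lambda body is illustrative only:

    // One gradient per op output in, one gradient per op input out.
    Func<Operation, Tensor[], Tensor[]> grad_fn = (oper, out_grads) => math_grad._AddGrad(oper, out_grads);
    Tensor[] in_grads = grad_fn(op, out_grads);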
@@ -10,12 +10,13 @@ namespace Tensorflow.Gradients
     /// </summary>
     public class math_grad
     {
-        public static (Tensor, Tensor) _AddGrad(Operation op, Tensor grad)
+        public static Tensor[] _AddGrad(Operation op, Tensor[] grads)
         {
             var x = op.inputs[0];
             var y = op.inputs[1];
+            var grad = grads[0];
             if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad))
-                return (grad, grad);
+                return new Tensor[] { grad, grad };
             var sx = array_ops.shape(x);
             var sy = array_ops.shape(y);
@@ -24,21 +25,22 @@ namespace Tensorflow.Gradients
             var r1 = gen_array_ops.reshape(math_ops.reduce_sum(grad, rx), sx);
             var r2 = gen_array_ops.reshape(math_ops.reduce_sum(grad, ry), sy);
-            return (r1, r2);
+            return new Tensor[] { r1, r2 };
         }
-        public static Tensor _IdGrad(Operation op, Tensor grad)
+        public static Tensor[] _IdGrad(Operation op, Tensor[] grads)
         {
-            return grad;
+            return new Tensor[] { grads[0] };
         }
-        public static (Tensor, Tensor) _MulGrad(Operation op, Tensor grad)
+        public static Tensor[] _MulGrad(Operation op, Tensor[] grads)
         {
             var x = op.inputs[0];
             var y = op.inputs[1];
+            var grad = grads[0];
             if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad) &&
                 new TF_DataType[] { tf.int32, tf.float32 }.Contains(grad.dtype))
-                return (gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x));
+                return new Tensor[] { gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x) };
             var sx = array_ops.shape(x);
             var sy = array_ops.shape(y);
@@ -54,11 +56,12 @@ namespace Tensorflow.Gradients
             var reshape1 = gen_array_ops.reshape(reduce_sum1, sx);
             var reshape2 = gen_array_ops.reshape(reduce_sum2, sy);
-            return (reshape1, reshape2);
+            return new Tensor[] { reshape1, reshape2 };
         }
-        public static (Tensor, Tensor) _MatMulGrad(Operation op, Tensor grad)
+        public static Tensor[] _MatMulGrad(Operation op, Tensor[] grads)
         {
+            var grad = grads[0];
             Tensor grad_a = null, grad_b = null;
             var t_a = (bool)op.get_attr("transpose_a");
@@ -86,12 +89,13 @@ namespace Tensorflow.Gradients
                 grad_b = gen_math_ops.mat_mul(grad, a, transpose_a: true, transpose_b: true);
             }
-            return (grad_a, grad_b);
+            return new Tensor[] { grad_a, grad_b };
         }
-        public static (Tensor, Tensor) _MeanGrad(Operation op, Tensor grad)
+        public static Tensor[] _MeanGrad(Operation op, Tensor[] grads)
         {
-            var sum_grad = _SumGrad(op, grad).Item1;
+            var grad = grads[0];
+            var sum_grad = _SumGrad(op, grads)[0];
             var input_shape = op.inputs[0]._shape_tuple();
             var output_shape = op.outputs[0]._shape_tuple();
@@ -99,7 +103,7 @@ namespace Tensorflow.Gradients
             var output_shape_tensor = array_ops.shape(op.outputs[0]);
             var factor = _safe_shape_div(math_ops.reduce_prod(input_shape_tensor), math_ops.reduce_prod(output_shape_tensor));
-            return (math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), null);
+            return new Tensor[] { math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), null };
         }
         private static Tensor _safe_shape_div(Tensor x, Tensor y)
@@ -107,12 +111,13 @@ namespace Tensorflow.Gradients
             return math_ops.floordiv(x, gen_math_ops.maximum(y, 1));
         }
-        public static (Tensor, Tensor) _SubGrad(Operation op, Tensor grad)
+        public static Tensor[] _SubGrad(Operation op, Tensor[] grads)
         {
+            var grad = grads[0];
             var x = op.inputs[0];
             var y = op.inputs[1];
             if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad))
-                return (grad, -grad);
+                return new Tensor[] { grad, -grad };
             var sx = array_ops.shape(x);
             var sy = array_ops.shape(y);
@@ -121,7 +126,7 @@ namespace Tensorflow.Gradients
             var r1 = gen_array_ops.reshape(math_ops.reduce_sum(grad, rx), sx);
             var r2 = gen_array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy);
-            return (r1, r2);
+            return new Tensor[] { r1, r2 };
         }
         public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad)
@@ -129,8 +134,9 @@ namespace Tensorflow.Gradients
             return x.NDims == y.NDims && y.NDims == grad.NDims && x.NDims > -1;
         }
-        public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad)
+        public static Tensor[] _SumGrad(Operation op, Tensor[] grads)
         {
+            var grad = grads[0];
             var input_0_shape = op.inputs[0]._shape_tuple();
             Tensor input_shape = null;
@@ -145,7 +151,7 @@ namespace Tensorflow.Gradients
                         input_shape = constant_op.constant(input_0_shape);
                     else
                         input_shape = array_ops.shape(op.inputs[0]);
-                    return (gen_array_ops.tile(grad, input_shape), null);
+                    return new Tensor[] { gen_array_ops.tile(grad, input_shape), null };
                 }
             }
@@ -155,11 +161,12 @@ namespace Tensorflow.Gradients
             var tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims);
             grad = gen_array_ops.reshape(grad, output_shape_kept_dims);
-            return (gen_array_ops.tile(grad, tile_scaling), null);
+            return new Tensor[] { gen_array_ops.tile(grad, tile_scaling), null };
         }
-        public static (Tensor, Tensor) _RealDivGrad(Operation op, Tensor grad)
+        public static Tensor[] _RealDivGrad(Operation op, Tensor[] grads)
         {
+            var grad = grads[0];
             var x = op.inputs[0];
             var y = op.inputs[1];
@@ -177,11 +184,12 @@ namespace Tensorflow.Gradients
             var reduce_sum2 = math_ops.reduce_sum(realdiv3, rx);
             var reshape2 = gen_array_ops.reshape(reduce_sum2, sx);
-            return (reshape2, reshape1);
+            return new Tensor[] { reshape2, reshape1 };
         }
-        public static (Tensor, Tensor) _PowGrad(Operation op, Tensor grad)
+        public static Tensor[] _PowGrad(Operation op, Tensor[] grads)
         {
+            var grad = grads[0];
             var x = op.inputs[0];
             var y = op.inputs[1];
             var z = op.outputs[0];
@@ -212,7 +220,7 @@ namespace Tensorflow.Gradients
             var reduce_sum1 = math_ops.reduce_sum(mul1, ry);
             var gy = gen_array_ops.reshape(reduce_sum1, sy);
-            return (gx, gy);
+            return new Tensor[] { gx, gy };
         }
     }
 }
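All of the binary gradients above (_AddGrad, _SubGrad, _MulGrad, _RealDivGrad, _PowGrad) share the same broadcasting recipe: take both input shapes, reduce the elementwise gradient over the broadcast axes, and reshape back to each input's shape. The reduction indices rx/ry come from context not shown in these hunks; the sketch below assumes a gen_array_ops.broadcast_gradient_args helper returning an (rx, ry) pair, which this diff does not confirm:

    // Illustrative helper only: the reduce-then-reshape pattern used by the
    // broadcasting binary gradients in math_grad.
    private static Tensor[] _BroadcastingBinaryGrad(Tensor x, Tensor y, Tensor gx, Tensor gy)
    {
        var sx = array_ops.shape(x);
        var sy = array_ops.shape(y);
        // Assumed API: yields the axes along which each operand was broadcast.
        var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
        // Sum the elementwise gradients over the broadcast axes, then restore the
        // original operand shapes so the results line up with op.inputs.
        var grad_x = gen_array_ops.reshape(math_ops.reduce_sum(gx, rx), sx);
        var grad_y = gen_array_ops.reshape(math_ops.reduce_sum(gy, ry), sy);
        return new Tensor[] { grad_x, grad_y };
    }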
@@ -13,16 +13,32 @@ namespace Tensorflow.Gradients
         /// <param name="op"></param>
         /// <param name="grad"></param>
         /// <returns></returns>
-        public static (Tensor, Tensor) _BiasAddGrad(Operation op, Tensor grad)
+        public static Tensor[] _BiasAddGrad(Operation op, Tensor grad)
         {
             string data_format = op.get_attr("data_format")?.ToString();
             var bias_add_grad = gen_nn_ops.bias_add_grad(out_backprop: grad, data_format: data_format);
-            return (grad, bias_add_grad);
+            return new Tensor[] { grad, bias_add_grad };
         }
-        public static (Tensor, Tensor) _ReluGrad(Operation op, Tensor grad)
+        public static Tensor[] _ReluGrad(Operation op, Tensor grad)
         {
-            return (gen_nn_ops.relu_grad(grad, op.outputs[0]), null);
+            return new Tensor[] { gen_nn_ops.relu_grad(grad, op.outputs[0]) };
         }
+
+        /// <summary>
+        /// Gradient function for SoftmaxCrossEntropyWithLogits.
+        /// </summary>
+        /// <param name="op"></param>
+        /// <param name="grads">grads[0] is grad_loss, grads[1] is grad_grad.</param>
+        /// <returns></returns>
+        public static Tensor[] _SoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[] grads)
+        {
+            var grad_loss = grads[0];
+            var grad_grad = grads[1];
+            var softmax_grad = op.outputs[1];
+
+            throw new NotImplementedException("_SoftmaxCrossEntropyWithLogitsGrad");
+        }
     }
 }
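_SoftmaxCrossEntropyWithLogitsGrad is left as a stub here. For context, the Python reference's common path multiplies the incoming loss gradient into the softmax gradient cached in op.outputs[1]. A hedged sketch of that path only, assuming grads[1] (grad_grad) is zero and that an array_ops.expand_dims overload exists in this codebase (neither is confirmed by this diff); the second-order term and the labels gradient are omitted:

    public static Tensor[] _SoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[] grads)
    {
        var grad_loss = grads[0];          // upstream gradient of the per-example loss
        var softmax_grad = op.outputs[1];  // cached softmax backprop term, shape [batch, classes]
        // Broadcast the per-example loss gradient across the class dimension.
        var grad = gen_math_ops.mul(array_ops.expand_dims(grad_loss, -1), softmax_grad);
        return new Tensor[] { grad, null };
    }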
@@ -0,0 +1,47 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Gradients;
+
+namespace Tensorflow
+{
+    public partial class ops
+    {
+        public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operation op)
+        {
+            if (op.inputs == null) return null;
+
+            // map tensorflow\python\ops\math_grad.py
+            return (oper, out_grads) =>
+            {
+                Console.WriteLine($"get_gradient_function: {oper.type} '{oper.name}'");
+
+                switch (oper.type)
+                {
+                    case "Add":
+                        return math_grad._AddGrad(oper, out_grads);
+                    case "Identity":
+                        return math_grad._IdGrad(oper, out_grads);
+                    case "Mul":
+                        return math_grad._MulGrad(oper, out_grads);
+                    case "Mean":
+                        return math_grad._MeanGrad(oper, out_grads);
+                    case "Sum":
+                        return math_grad._SumGrad(oper, out_grads);
+                    case "Sub":
+                        return math_grad._SubGrad(oper, out_grads);
+                    case "Pow":
+                        return math_grad._PowGrad(oper, out_grads);
+                    case "RealDiv":
+                        return math_grad._RealDivGrad(oper, out_grads);
+                    case "Reshape":
+                        return array_grad._ReshapeGrad(oper, out_grads);
+                    case "SoftmaxCrossEntropyWithLogits":
+                        return nn_grad._SoftmaxCrossEntropyWithLogitsGrad(oper, out_grads);
+                    default:
+                        throw new NotImplementedException($"get_gradient_function {oper.type}");
+                }
+            };
+        }
+    }
+}
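One note on coverage of the new mapping: math_grad._MatMulGrad is converted to the array convention earlier in this PR but has no case in the switch, so MatMul still falls through to NotImplementedException. If it should be wired up, a hypothetical extra case (not part of this diff) would be:

    case "MatMul":
        return math_grad._MatMulGrad(oper, out_grads);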
@@ -346,50 +346,6 @@ namespace Tensorflow
             session.run(operation, feed_dict);
         }

-        public static Func<Operation, Tensor, Tensor[]> get_gradient_function(Operation op)
-        {
-            if (op.inputs == null) return null;
-
-            // map tensorflow\python\ops\math_grad.py
-            return (oper, out_grads) =>
-            {
-                // Console.WriteLine($"get_gradient_function: {oper.type} '{oper.name}'");
-
-                switch (oper.type)
-                {
-                    case "Add":
-                        var add = math_grad._AddGrad(oper, out_grads);
-                        return new Tensor[] { add.Item1, add.Item2 };
-                    case "Identity":
-                        var id = math_grad._IdGrad(oper, out_grads);
-                        return new Tensor[] { id };
-                    case "Mul":
-                        var mul = math_grad._MulGrad(oper, out_grads);
-                        return new Tensor[] { mul.Item1, mul.Item2 };
-                    case "Mean":
-                        var mean = math_grad._MeanGrad(oper, out_grads);
-                        return new Tensor[] { mean.Item1, mean.Item2 };
-                    case "Sum":
-                        var sum = math_grad._SumGrad(oper, out_grads);
-                        return new Tensor[] { sum.Item1, sum.Item2 };
-                    case "Sub":
-                        var sub = math_grad._SubGrad(oper, out_grads);
-                        return new Tensor[] { sub.Item1, sub.Item2 };
-                    case "Pow":
-                        var pow = math_grad._PowGrad(oper, out_grads);
-                        return new Tensor[] { pow.Item1, pow.Item2 };
-                    case "RealDiv":
-                        var realdiv = math_grad._RealDivGrad(oper, out_grads);
-                        return new Tensor[] { realdiv.Item1, realdiv.Item2 };
-                    case "Reshape":
-                        var reshape = array_grad._ReshapeGrad(oper, out_grads);
-                        return new Tensor[] { reshape.Item1, reshape.Item2 };
-                    default:
-                        throw new NotImplementedException($"get_gradient_function {oper.type}");
-                }
-            };
-        }
         public static Tensor[] convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
         {
             return internal_convert_n_to_tensor_or_indexed_slices(values, dtype: dtype, name: name);