Arnav Das, 6 years ago
commit a406dc21b7
5 changed files with 161 additions and 30 deletions
  1. +146 -0   src/TensorFlowNET.Core/Gradients/array_grad.cs
  2. +0   -30  src/TensorFlowNET.Core/Gradients/array_grad.py.cs
  3. +3   -0   src/TensorFlowNET.Core/Gradients/nn_grad.cs
  4. +2   -0   src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
  5. +10  -0   src/TensorFlowNET.Core/Tensors/constant_op.cs

+146 -0  src/TensorFlowNET.Core/Gradients/array_grad.cs

@@ -0,0 +1,146 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Operations;
using static Tensorflow.Python;

namespace Tensorflow.Gradients
{
    /// <summary>
    /// tensorflow\python\ops\array_grad.py
    /// </summary>
    public class array_grad
    {
        public static Tensor[] _ConcatGradV2(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
            return _ConcatGradHelper(op, grad, start_value_index: 0, end_value_index: -1, dim_index: -1);
        }

        /// <summary>
        /// Gradient for concat op.
        /// </summary>
        /// <param name="op">An operation.</param>
        /// <param name="grad">
        /// `Tensor` or `IndexedSlices` representing the gradients with respect
        /// to each output of the op.
        /// </param>
        /// <param name="start_value_index">An integer index of the first value in op.inputs.</param>
        /// <param name="end_value_index">An integer index of the last value in op.inputs; may be negative (counted from the end, Python-style).</param>
        /// <param name="dim_index">An integer index of the concat_dim or axis parameter in op.inputs; may be negative.</param>
        /// <returns>
        /// Tensors representing the partial gradients with respect to each input
        /// of the op.
        /// </returns>
        private static Tensor[] _ConcatGradHelper(Operation op, Tensor grad, int start_value_index, int end_value_index, int dim_index)
        {
            // end_value_index and dim_index arrive Python-style and may be negative;
            // normalize them against len(op.inputs) so the Skip/Take slicing and the
            // comparisons below match the semantics of the original Python code.
            if (end_value_index < 0)
                end_value_index = len(op.inputs) + end_value_index;
            if (dim_index < 0)
                dim_index = len(op.inputs) + dim_index;

            // Degenerate concatenation, just return grad.
            if (len(op.inputs) == 2)
                return end_value_index <= dim_index ? new Tensor[] { grad, null } : new Tensor[] { null, grad };

            var concat_dim = op.inputs[dim_index];
            var input_values = op.inputs._inputs.Skip(start_value_index).Take(end_value_index - start_value_index).ToArray();

            var out_grads = new List<Tensor>();
            if (constant_op.is_constant(concat_dim))
            {
                /* If concat_dim is a constant defined in a different context,
                then we duplicate it in the current context to avoid passing it
                through an Enter node.
                This is a small optimization in general, but it is required when
                compiling with XLA, as XLA needs the concat input to be folded into a
                constant. */
                var grad_context = control_flow_util.GetOutputContext(grad.op);
                var dim_context = control_flow_util.GetOutputContext(concat_dim.op);
                if (dim_context != grad_context)
                {
                    var value = tensor_util.constant_value(concat_dim);
                    concat_dim = constant_op.constant(value: value, dtype: concat_dim.dtype);
                }
            }

            // Using mod here for convenience since concat_dim is already verified
            // in the concat implementation to be within the allowed [-rank, rank) range.
            var non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0]);

            // Get the inputs' tensor shapes.
            var sizes = _ExtractInputShapes(input_values);

            /* The magic number of 16 was found through benchmarking a range of sizes
            on CPUs and a Maxwell TitanX. A speedup was seen in a large majority of
            cases when switching implementations at N=16, but it is possible that
            there will be a small number of performance regressions. */
            if (len(sizes) > 16)
            {
                // Extract the size of each input along the concat dimension.
                var slice = array_ops.slice(array_ops.stack(sizes, axis: 1),
                    new Tensor[] { non_neg_concat_dim, tf.constant(0) },
                    new Tensor[] { tf.constant(1), tf.constant(-1) });
                var squeeze_sizes = array_ops.squeeze(slice);
                out_grads = gen_ops.split(grad, squeeze_sizes, non_neg_concat_dim).ToList();
            }
            else
            {
                var offset = gen_ops.concat_offset(non_neg_concat_dim, sizes);
                foreach (var (begin, size) in zip(offset, sizes))
                    out_grads.Add(gen_ops.slice(grad, begin, size));
            }

            // The axis input is not differentiable, so its slot gets null.
            return (end_value_index <= dim_index ?
                out_grads.Concat(new Tensor[] { null }) :
                new Tensor[] { null }.Concat(out_grads)).ToArray();
        }

        /// <summary>
        /// Extract the shapes of a set of input tensors.
        /// </summary>
        /// <param name="inputs">The tensors whose shapes are needed.</param>
        /// <returns>One shape tensor per input.</returns>
        private static Tensor[] _ExtractInputShapes(Tensor[] inputs)
        {
            var sizes = new Tensor[inputs.Length];
            bool fully_known = true;
            for (int i = 0; i < inputs.Length; i++)
            {
                var x = inputs[i];

                var input_shape = array_ops.shape(x);
                if (!(input_shape is Tensor) || input_shape.op.type != "Const")
                {
                    // Shape is not statically known; fall back to a single ShapeN op.
                    fully_known = false;
                    break;
                }

                sizes[i] = input_shape;
            }

            if (fully_known)
                return sizes;
            else
                return gen_ops.shape_n(inputs);
        }


        public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads)
        {
            return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null };
        }

        public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads)
        {
            return new Tensor[] { _ReshapeToInput(op, grads[0]) };
        }

        private static Tensor _ReshapeToInput(Operation op, Tensor grad)
        {
            return array_ops.reshape(grad, array_ops.shape(op.inputs[0]));
        }

        public static Tensor[] _TransposeGrad(Operation op, Tensor[] grads)
        {
            var p = op.inputs[1];
            return new Tensor[] { array_ops.transpose(grads[0], array_ops.invert_permutation(p)), null };
        }
    }
}
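A worked example of the bookkeeping the small-N branch performs may help. The sketch below is plain C#, not code from this commit; it only mirrors the ConcatOffset-then-Slice arithmetic for two inputs of shape [2, 3] and [2, 2] concatenated along axis 1, and every name in it is illustrative.

using System;

class ConcatGradSketch
{
    static void Main()
    {
        // Hedged sketch: reproduce the begin/size pairs that ConcatOffset + Slice
        // would produce for grad of shape [2, 5] = concat([2, 3], [2, 2], axis: 1).
        int axis = 1;
        int[][] sizes = { new[] { 2, 3 }, new[] { 2, 2 } };

        int running = 0;                      // offset along the concat axis
        foreach (var size in sizes)
        {
            var begin = new int[size.Length]; // zeros except along the concat axis
            begin[axis] = running;
            running += size[axis];
            Console.WriteLine($"slice(grad, begin: [{string.Join(", ", begin)}], size: [{string.Join(", ", size)}])");
        }
        // Prints:
        //   slice(grad, begin: [0, 0], size: [2, 3])
        //   slice(grad, begin: [0, 3], size: [2, 2])
        // i.e. each input's gradient is the matching strip of the incoming gradient.
    }
}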

+0 -30  src/TensorFlowNET.Core/Gradients/array_grad.py.cs

@@ -1,30 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Gradients
{
    public class array_grad
    {
        public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads)
        {
            return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null };
        }

        public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads)
        {
            return new Tensor[] { _ReshapeToInput(op, grads[0]) };
        }

        private static Tensor _ReshapeToInput(Operation op, Tensor grad)
        {
            return array_ops.reshape(grad, array_ops.shape(op.inputs[0]));
        }

        public static Tensor[] _TransposeGrad(Operation op, Tensor[] grads)
        {
            var p = op.inputs[1];
            return new Tensor[] { array_ops.transpose(grads[0], array_ops.invert_permutation(p)), null };
        }
    }
}

+3 -0  src/TensorFlowNET.Core/Gradients/nn_grad.py.cs → src/TensorFlowNET.Core/Gradients/nn_grad.cs

@@ -6,6 +6,9 @@ using Tensorflow.Operations;

namespace Tensorflow.Gradients
{
    /// <summary>
    /// tensorflow\python\ops\nn_grad.py
    /// </summary>
    public class nn_grad
    {
        /// <summary>

+2 -0  src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs

@@ -22,6 +22,8 @@ namespace Tensorflow
                return math_grad._AddGrad(oper, out_grads);
            case "BiasAdd":
                return nn_grad._BiasAddGrad(oper, out_grads);
            case "ConcatV2":
                return array_grad._ConcatGradV2(oper, out_grads);
            case "Exp":
                return math_grad._ExpGrad(oper, out_grads);
            case "Identity":


+10 -0  src/TensorFlowNET.Core/Tensors/constant_op.cs

@@ -88,5 +88,15 @@ namespace Tensorflow

            return constant_op.constant(s_list, name: name);
        }

        public static bool is_constant(ITensorOrOperation tensor_or_op)
        {
            if (tensor_or_op is Tensor tensor)
                return tensor.op.type == "Const";
            else if (tensor_or_op is Operation op)
                return op.type == "Const";
            else
                throw new ValueError($"is_constant: unexpected type {tensor_or_op.GetType().Name}");
        }
    }
}
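A quick sketch of the new helper's contract. `constant_op.constant` appears elsewhere in this commit; the `+` operator on `Tensor` is assumed here purely for illustration.

// Hedged usage sketch: is_constant only inspects the type of the producing op.
var c = constant_op.constant(3.0f);      // backed by a "Const" op
var s = c + c;                           // backed by an Add-style op (assumed operator)
bool r1 = constant_op.is_constant(c);    // true
bool r2 = constant_op.is_constant(s);    // false
// _ConcatGradHelper uses this to decide whether concat_dim may be duplicated
// as a fresh constant in the gradient's control-flow context.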
