Browse Source

_SliceGrad #463

tags/v0.20
Oceania2018 5 years ago
parent
commit
cf8465e6b7
1 changed files with 21 additions and 0 deletions
  1. +21
    -0
      src/TensorFlowNET.Core/Gradients/array_grad.cs

+ 21
- 0
src/TensorFlowNET.Core/Gradients/array_grad.cs View File

@@ -237,6 +237,27 @@ namespace Tensorflow.Gradients
return new Tensor[] { null, array_ops.concat(list(grads), op.inputs[0]) };
}

[RegisterGradient("Slice")]
public static Tensor[] _SliceGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
var input_vec = op.inputs[0];
var begin_vec = op.inputs[1];
var input_rank = array_ops.rank(input_vec);
var slice_size = array_ops.shape(op.outputs[0]);

var shape = array_ops.stack(new Tensor[] { input_rank, new Tensor(1) });
var before_pad = array_ops.reshape(begin_vec, shape);
var after_pad = array_ops.reshape(array_ops.shape(input_vec) - slice_size - begin_vec, shape);
var paddings = array_ops.concat(new Tensor[] { before_pad, after_pad }, 1);
return new Tensor[]
{
array_ops.pad(grad, paddings),
null,
null
};
}

[RegisterGradient("Squeeze")]
public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads)
{


Loading…
Cancel
Save