From cf8465e6b71a61abed93c0f3f9ca609a8aebe6dd Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 22 Dec 2019 08:32:18 -0600 Subject: [PATCH] _SliceGrad #463 --- .../Gradients/array_grad.cs | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs index b7ca5cf9..33c5f7c5 100644 --- a/src/TensorFlowNET.Core/Gradients/array_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs @@ -237,6 +237,27 @@ namespace Tensorflow.Gradients return new Tensor[] { null, array_ops.concat(list(grads), op.inputs[0]) }; } + [RegisterGradient("Slice")] + public static Tensor[] _SliceGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var input_vec = op.inputs[0]; + var begin_vec = op.inputs[1]; + var input_rank = array_ops.rank(input_vec); + var slice_size = array_ops.shape(op.outputs[0]); + + var shape = array_ops.stack(new Tensor[] { input_rank, new Tensor(1) }); + var before_pad = array_ops.reshape(begin_vec, shape); + var after_pad = array_ops.reshape(array_ops.shape(input_vec) - slice_size - begin_vec, shape); + var paddings = array_ops.concat(new Tensor[] { before_pad, after_pad }, 1); + return new Tensor[] + { + array_ops.pad(grad, paddings), + null, + null + }; + } + [RegisterGradient("Squeeze")] public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads) {