From 824dfe6aaf58b0c5b3c34e02dfa5d8404bcf8a23 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 5 Sep 2021 15:17:37 -0500 Subject: [PATCH] Pack/Unpack gradient. #847 --- src/TensorFlowNET.Core/Gradients/array_grad.cs | 16 ++++++++++++++++ src/TensorFlowNET.Core/Operations/array_ops.cs | 10 +--------- .../Operations/gen_array_ops.cs | 6 ++---- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs index f80f8ac6..528b5208 100644 --- a/src/TensorFlowNET.Core/Gradients/array_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs @@ -223,6 +223,22 @@ namespace Tensorflow.Gradients return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null }; } + [RegisterGradient("Pack")] + public static Tensor[] _PackGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var num = op.get_attr("N"); + var axis = op.get_attr("axis"); + return array_ops.unstack(grad, num: num, axis: axis); + } + + [RegisterGradient("Unpack")] + public static Tensor[] _UnpackGrad(Operation op, Tensor[] grads) + { + var axis = op.get_attr("axis"); + return new[] { array_ops.stack(grads, axis: axis) }; + } + [RegisterGradient("Pad")] public static Tensor[] _PadGrad(Operation op, Tensor[] grads) { diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index b0ef1f2d..d13e0005 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -494,20 +494,12 @@ namespace Tensorflow return ops.convert_to_tensor(values, name: name); } - var value_shape = ops.convert_to_tensor(values[0], name: name).shape; - return gen_array_ops.pack(values, axis: axis, name: name); } public static Tensor[] unstack(Tensor value, int? num = null, int axis = 0, string name = "unstack") { - if (num == null) - { - value = ops.convert_to_tensor(value); - var value_shape = value.shape; - num = (int)value_shape.dims[axis]; - } - + num = num ?? value.shape.as_int_list()[axis]; return gen_array_ops.unpack(value, num: num.Value, axis: axis, name: name); } diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 65599a4c..dd1604f6 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -265,10 +265,8 @@ namespace Tensorflow } public static Tensor[] unpack(Tensor value, int num, int axis = 0, string name = null) - { - var _op = tf.OpDefLib._apply_op_helper("Unpack", name, new { value, num, axis }); - return _op.outputs; - } + => tf.Context.ExecuteOp("Unpack", name, new ExecuteOpArgs(value, num) + .SetAttributes(new { axis })); public static Tensor where(Tensor condition, string name = null) {