Browse Source

Pack/Unpack gradient. #847

tags/TensorFlowOpLayer
Oceania2018 4 years ago
parent
commit
824dfe6aaf
3 changed files with 19 additions and 13 deletions
  1. +16
    -0
      src/TensorFlowNET.Core/Gradients/array_grad.cs
  2. +1
    -9
      src/TensorFlowNET.Core/Operations/array_ops.cs
  3. +2
    -4
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs

+ 16
- 0
src/TensorFlowNET.Core/Gradients/array_grad.cs View File

@@ -223,6 +223,22 @@ namespace Tensorflow.Gradients
return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null };
}

[RegisterGradient("Pack")]
public static Tensor[] _PackGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
var num = op.get_attr<int>("N");
var axis = op.get_attr<int>("axis");
return array_ops.unstack(grad, num: num, axis: axis);
}

[RegisterGradient("Unpack")]
public static Tensor[] _UnpackGrad(Operation op, Tensor[] grads)
{
var axis = op.get_attr<int>("axis");
return new[] { array_ops.stack(grads, axis: axis) };
}

[RegisterGradient("Pad")]
public static Tensor[] _PadGrad(Operation op, Tensor[] grads)
{


+ 1
- 9
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -494,20 +494,12 @@ namespace Tensorflow
return ops.convert_to_tensor(values, name: name);
}

var value_shape = ops.convert_to_tensor(values[0], name: name).shape;

return gen_array_ops.pack(values, axis: axis, name: name);
}

public static Tensor[] unstack(Tensor value, int? num = null, int axis = 0, string name = "unstack")
{
if (num == null)
{
value = ops.convert_to_tensor(value);
var value_shape = value.shape;
num = (int)value_shape.dims[axis];
}

num = num ?? value.shape.as_int_list()[axis];
return gen_array_ops.unpack(value, num: num.Value, axis: axis, name: name);
}



+ 2
- 4
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -265,10 +265,8 @@ namespace Tensorflow
}

public static Tensor[] unpack(Tensor value, int num, int axis = 0, string name = null)
{
var _op = tf.OpDefLib._apply_op_helper("Unpack", name, new { value, num, axis });
return _op.outputs;
}
=> tf.Context.ExecuteOp("Unpack", name, new ExecuteOpArgs(value, num)
.SetAttributes(new { axis }));

public static Tensor where(Tensor condition, string name = null)
{


Loading…
Cancel
Save