@@ -223,6 +223,22 @@ namespace Tensorflow.Gradients | |||||
return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null }; | return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null }; | ||||
} | } | ||||
[RegisterGradient("Pack")] | |||||
public static Tensor[] _PackGrad(Operation op, Tensor[] grads) | |||||
{ | |||||
var grad = grads[0]; | |||||
var num = op.get_attr<int>("N"); | |||||
var axis = op.get_attr<int>("axis"); | |||||
return array_ops.unstack(grad, num: num, axis: axis); | |||||
} | |||||
[RegisterGradient("Unpack")] | |||||
public static Tensor[] _UnpackGrad(Operation op, Tensor[] grads) | |||||
{ | |||||
var axis = op.get_attr<int>("axis"); | |||||
return new[] { array_ops.stack(grads, axis: axis) }; | |||||
} | |||||
[RegisterGradient("Pad")] | [RegisterGradient("Pad")] | ||||
public static Tensor[] _PadGrad(Operation op, Tensor[] grads) | public static Tensor[] _PadGrad(Operation op, Tensor[] grads) | ||||
{ | { | ||||
@@ -494,20 +494,12 @@ namespace Tensorflow | |||||
return ops.convert_to_tensor(values, name: name); | return ops.convert_to_tensor(values, name: name); | ||||
} | } | ||||
var value_shape = ops.convert_to_tensor(values[0], name: name).shape; | |||||
return gen_array_ops.pack(values, axis: axis, name: name); | return gen_array_ops.pack(values, axis: axis, name: name); | ||||
} | } | ||||
public static Tensor[] unstack(Tensor value, int? num = null, int axis = 0, string name = "unstack") | public static Tensor[] unstack(Tensor value, int? num = null, int axis = 0, string name = "unstack") | ||||
{ | { | ||||
if (num == null) | |||||
{ | |||||
value = ops.convert_to_tensor(value); | |||||
var value_shape = value.shape; | |||||
num = (int)value_shape.dims[axis]; | |||||
} | |||||
num = num ?? value.shape.as_int_list()[axis]; | |||||
return gen_array_ops.unpack(value, num: num.Value, axis: axis, name: name); | return gen_array_ops.unpack(value, num: num.Value, axis: axis, name: name); | ||||
} | } | ||||
@@ -265,10 +265,8 @@ namespace Tensorflow | |||||
} | } | ||||
public static Tensor[] unpack(Tensor value, int num, int axis = 0, string name = null) | public static Tensor[] unpack(Tensor value, int num, int axis = 0, string name = null) | ||||
{ | |||||
var _op = tf.OpDefLib._apply_op_helper("Unpack", name, new { value, num, axis }); | |||||
return _op.outputs; | |||||
} | |||||
=> tf.Context.ExecuteOp("Unpack", name, new ExecuteOpArgs(value, num) | |||||
.SetAttributes(new { axis })); | |||||
public static Tensor where(Tensor condition, string name = null) | public static Tensor where(Tensor condition, string name = null) | ||||
{ | { | ||||