|
- using Tensorflow.Keras.ArgsDefinition.Reshaping;
- using Tensorflow.Keras.Engine;
- using Tensorflow.Keras.Saving;
- using Tensorflow.Common.Types;
-
- namespace Tensorflow.Keras.Layers.Reshaping
- {
- /// <summary>
- /// Similar to copping 2D
- /// </summary>
- public class Cropping3D : Layer
- {
- Cropping3DArgs args;
- public Cropping3D(Cropping3DArgs args) : base(args)
- {
- this.args = args;
- }
-
- public override void build(KerasShapesWrapper input_shape)
- {
- built = true;
- _buildInputShape = input_shape;
- }
-
- <<<<<<< HEAD
- <<<<<<< HEAD
- protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
- =======
- protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
- >>>>>>> master
- =======
- protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
- >>>>>>> 90a65d7d98b92f26574ac32392ed802a57d4d2c8
- {
- Tensor output = inputs;
- if (output.rank != 5)
- {
- // throw an ValueError exception
- throw new ValueError("Expected dim=5, found dim=" + output.rank);
- }
-
- if (args.cropping.shape == new Shape(1))
- {
- int crop = args.cropping[0];
- if (args.data_format == Cropping3DArgs.DataFormat.channels_last)
- {
- output = output[new Slice(),
- new Slice(crop, (int)output.shape[1] - crop),
- new Slice(crop, (int)output.shape[2] - crop),
- new Slice(crop, (int)output.shape[3] - crop),
- new Slice()];
- }
- else
- {
- output = output[new Slice(),
- new Slice(),
- new Slice(crop, (int)output.shape[2] - crop),
- new Slice(crop, (int)output.shape[3] - crop),
- new Slice(crop, (int)output.shape[4] - crop)];
- }
-
- }
- // int[1][3] equivalent to a tuple of 3 integers
- else if (args.cropping.shape == new Shape(3))
- {
- var crop_1 = args.cropping[0];
- var crop_2 = args.cropping[1];
- var crop_3 = args.cropping[2];
- if (args.data_format == Cropping3DArgs.DataFormat.channels_last)
- {
- output = output[new Slice(),
- new Slice(crop_1, (int)output.shape[1] - crop_1),
- new Slice(crop_2, (int)output.shape[2] - crop_2),
- new Slice(crop_3, (int)output.shape[3] - crop_3),
- new Slice()];
- }
- else
- {
- output = output[new Slice(),
- new Slice(),
- new Slice(crop_1, (int)output.shape[2] - crop_1),
- new Slice(crop_2, (int)output.shape[3] - crop_2),
- new Slice(crop_3, (int)output.shape[4] - crop_3)];
- }
- }
- else if (args.cropping.shape[0] == 3 && args.cropping.shape[1] == 2)
- {
- int x = args.cropping[0, 0], x_end = args.cropping[0, 1];
- int y = args.cropping[1, 0], y_end = args.cropping[1, 1];
- int z = args.cropping[2, 0], z_end = args.cropping[2, 1];
- if (args.data_format == Cropping3DArgs.DataFormat.channels_last)
- {
- output = output[new Slice(),
- new Slice(x, (int)output.shape[1] - x_end),
- new Slice(y, (int)output.shape[2] - y_end),
- new Slice(z, (int)output.shape[3] - z_end),
- new Slice()];
- }
- else
- {
- output = output[new Slice(),
- new Slice(),
- new Slice(x, (int)output.shape[2] - x_end),
- new Slice(y, (int)output.shape[3] - y_end),
- new Slice(z, (int)output.shape[4] - z_end)
- ];
- }
- }
- return output;
- }
- public override Shape ComputeOutputShape(Shape input_shape)
- {
- if (args.cropping.shape == new Shape(1))
- {
- int crop = args.cropping[0];
- if (args.data_format == Cropping3DArgs.DataFormat.channels_last)
- {
- return new Shape((int)input_shape[0], (int)input_shape[1] - crop * 2, (int)input_shape[2] - crop * 2, (int)input_shape[3] - crop * 2, (int)input_shape[4]);
- }
- else
- {
- return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - crop * 2, (int)input_shape[3] - crop * 2, (int)input_shape[4] - crop * 2);
- }
- }
- // int[1][3] equivalent to a tuple of 3 integers
- else if (args.cropping.shape == new Shape(3))
- {
- var crop_start_1 = args.cropping[0];
- var crop_start_2 = args.cropping[1];
- var crop_start_3 = args.cropping[2];
- if (args.data_format == Cropping3DArgs.DataFormat.channels_last)
- {
- return new Shape((int)input_shape[0], (int)input_shape[1] - crop_start_1 * 2, (int)input_shape[2] - crop_start_2 * 2, (int)input_shape[3] - crop_start_3 * 2, (int)input_shape[4]);
- }
- else
- {
- return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - crop_start_1 * 2, (int)input_shape[3] - crop_start_2 * 2, (int)input_shape[4] - crop_start_3 * 2);
- }
- }
- else if (args.cropping.shape == new Shape(3, 2))
- {
- int x = args.cropping[0, 0], x_end = args.cropping[0, 1];
- int y = args.cropping[1, 0], y_end = args.cropping[1, 1];
- int z = args.cropping[2, 0], z_end = args.cropping[2, 1];
- if (args.data_format == Cropping3DArgs.DataFormat.channels_last)
- {
- return new Shape((int)input_shape[0], (int)input_shape[1] - x - x_end, (int)input_shape[2] - y - y_end, (int)input_shape[3] - z - z_end, (int)input_shape[4]);
- }
- else
- {
- return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - x - x_end, (int)input_shape[3] - y - y_end, (int)input_shape[4] - z - z_end);
- }
- }
- else
- {
- throw new ValueError();
- }
- }
- }
- }
|