using Tensorflow.Keras.ArgsDefinition.Reshaping;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers.Reshaping
{
    /// <summary>
    /// Cropping layer for rank-5 inputs; the 3D counterpart of Cropping2D.
    /// Removes elements from the start and end of the three spatial axes
    /// (axes 1..3 for channels_last, axes 2..4 otherwise).
    /// </summary>
    public class Cropping3D : Layer
    {
        Cropping3DArgs args;

        public Cropping3D(Cropping3DArgs args) : base(args)
        {
            this.args = args;
        }

        public override void build(KerasShapesWrapper input_shape)
        {
            built = true;
            _buildInputShape = input_shape;
        }

        /// <summary>
        /// Crops the three spatial axes of <paramref name="inputs"/> according to
        /// <c>args.cropping</c> and <c>args.data_format</c>.
        /// </summary>
        /// <exception cref="ValueError">
        /// The input is not rank 5, or <c>args.cropping</c> does not have shape (1), (3) or (3, 2).
        /// </exception>
        protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
        {
            Tensor output = inputs;
            if (output.rank != 5)
            {
                // throw an ValueError exception
                throw new ValueError("Expected dim=5, found dim=" + output.rank);
            }

            var crops = GetCroppingPairs();
            int firstSpatialAxis = FirstSpatialAxis;

            // Start with a full slice on every axis; only the spatial axes are narrowed below.
            var slices = new Slice[5];
            for (int axis = 0; axis < slices.Length; axis++)
            {
                slices[axis] = new Slice();
            }

            for (int i = 0; i < 3; i++)
            {
                var (start, end) = crops[i];
                // A negative stop counts back from the end of the axis, so this works even when
                // the axis length is not statically known. end == 0 needs an open stop, because
                // a stop of -0 == 0 would select nothing.
                slices[firstSpatialAxis + i] = end > 0 ? new Slice(start, -end) : new Slice(start);
            }

            return output[slices];
        }

        /// <summary>
        /// Computes the output shape: each spatial dimension shrinks by its start + end cropping.
        /// Unknown (negative) dimensions stay unknown (-1).
        /// </summary>
        /// <exception cref="ValueError">
        /// <c>args.cropping</c> does not have shape (1), (3) or (3, 2).
        /// </exception>
        public override Shape ComputeOutputShape(Shape input_shape)
        {
            var crops = GetCroppingPairs();
            int firstSpatialAxis = FirstSpatialAxis;

            var dims = new long[5];
            for (int axis = 0; axis < dims.Length; axis++)
            {
                dims[axis] = input_shape[axis];
            }

            for (int i = 0; i < 3; i++)
            {
                int axis = firstSpatialAxis + i;
                var (start, end) = crops[i];
                dims[axis] = dims[axis] < 0 ? -1 : dims[axis] - start - end;
            }

            return new Shape(dims);
        }

        /// <summary>
        /// Index of the first spatial axis: 1 for channels_last, 2 otherwise (channels_first).
        /// </summary>
        private int FirstSpatialAxis =>
            args.data_format == Cropping3DArgs.DataFormat.channels_last ? 1 : 2;

        /// <summary>
        /// Normalizes <c>args.cropping</c> into one (start, end) pair per spatial axis.
        /// Accepted shapes:
        ///   (1)    - one integer, applied to both ends of all three axes;
        ///   (3)    - one integer per axis, applied symmetrically to both ends;
        ///   (3, 2) - explicit (start, end) per axis.
        /// Shared by <see cref="Call"/> and <see cref="ComputeOutputShape"/> so both
        /// accept exactly the same specs.
        /// </summary>
        /// <exception cref="ValueError">Any other cropping shape.</exception>
        private (int start, int end)[] GetCroppingPairs()
        {
            var cropping = args.cropping;
            var shape = cropping.shape;

            if (shape.ndim == 1 && shape[0] == 1)
            {
                int crop = cropping[0];
                return new[] { (crop, crop), (crop, crop), (crop, crop) };
            }

            // int[1][3] equivalent to a tuple of 3 integers
            if (shape.ndim == 1 && shape[0] == 3)
            {
                int crop1 = cropping[0];
                int crop2 = cropping[1];
                int crop3 = cropping[2];
                return new[] { (crop1, crop1), (crop2, crop2), (crop3, crop3) };
            }

            if (shape.ndim == 2 && shape[0] == 3 && shape[1] == 2)
            {
                int x = cropping[0, 0], x_end = cropping[0, 1];
                int y = cropping[1, 0], y_end = cropping[1, 1];
                int z = cropping[2, 0], z_end = cropping[2, 1];
                return new[] { (x, x_end), (y, y_end), (z, z_end) };
            }

            throw new ValueError("Cropping3D: `cropping` must have shape (1), (3) or (3, 2), found " + shape);
        }
    }
}