|
- using Tensorflow.Keras.ArgsDefinition;
- using Tensorflow.Keras.Engine;
-
namespace Tensorflow.Keras.Layers {
    /// <summary>
    /// Cropping layer for rank-3 input: removes elements from the start and the end of axis 1.
    /// </summary>
    /// <remarks>
    /// <c>args.cropping</c> holds either one integer (the same amount is cropped from both ends)
    /// or two integers (amount cropped from the start, amount cropped from the end).
    /// </remarks>
    public class Cropping1D : Layer {
        CroppingArgs args;

        public Cropping1D ( CroppingArgs args ) : base(args) {
            this.args = args;
        }

        /// <summary>
        /// Validates <c>args.cropping</c> and marks the layer as built.
        /// </summary>
        /// <param name="inputs">Unused; validation depends only on the layer arguments.</param>
        /// <exception cref="ValueError">
        /// If <c>args.cropping</c> is not rank 1, does not hold 1 or 2 values, or holds a negative value.
        /// </exception>
        protected override void build ( Tensors inputs ) {
            if ( args.cropping.rank != 1 ) {
                throw new ValueError("The `cropping` argument must be a 1-D array of 1 or 2 integers, but has rank "
                    + args.cropping.rank + ".");
            }
            if ( args.cropping.shape[0] > 2 || args.cropping.shape[0] < 1 ) {
                throw new ValueError("The `cropping` argument must be a single integer or a tuple of 2 integers.");
            }
            GetCropping(out int crop_start, out int crop_end);
            // Negative amounts would widen the slice instead of cropping it.
            if ( crop_start < 0 || crop_end < 0 ) {
                throw new ValueError("The `cropping` argument must contain non-negative integers, received ("
                    + crop_start + ", " + crop_end + ").");
            }
            built = true;
        }

        /// <summary>
        /// Crops axis 1 of <paramref name="inputs"/>, keeping the range [crop_start, length - crop_end).
        /// </summary>
        /// <param name="inputs">A rank-3 tensor.</param>
        /// <param name="state">Unused.</param>
        /// <param name="training">Unused.</param>
        /// <returns>The cropped tensor; axes 0 and 2 are unchanged.</returns>
        /// <exception cref="ValueError">
        /// If the input is not rank 3, or if the total cropping is not smaller than a known axis-1 length.
        /// </exception>
        protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
            Tensor output = inputs;
            if ( output.rank != 3 ) {
                throw new ValueError("Expected dim=3, found dim=" + output.rank);
            }
            GetCropping(out int crop_start, out int crop_end);
            int length = ( int ) output.shape[1];
            // A negative length means the dimension is not statically known, so it cannot be checked here.
            if ( length >= 0 && crop_start + crop_end >= length ) {
                throw new ValueError("Cropping amounts (" + crop_start + ", " + crop_end
                    + ") must sum to less than the input length along axis 1 (" + length + ").");
            }
            output = output[new Slice(), new Slice(crop_start, length - crop_end), new Slice()];
            return output;
        }

        /// <summary>
        /// Computes the output shape: axis 1 shrinks by the total cropping; axes 0 and 2 are unchanged.
        /// </summary>
        /// <param name="input_shape">A rank-3 input shape.</param>
        /// <returns>The cropped shape. An unknown (negative) axis-1 length stays unknown (-1).</returns>
        public override Shape ComputeOutputShape ( Shape input_shape ) {
            GetCropping(out int crop_start, out int crop_end);
            long length = input_shape[1];
            // Subtracting from an unknown (-1) length would yield a meaningless negative size.
            int cropped_length = length < 0 ? -1 : ( int ) (length - crop_start - crop_end);
            return new Shape(( int ) (input_shape[0]), cropped_length, ( int ) (input_shape[2]));
        }

        /// <summary>
        /// Reads the crop amounts from <c>args.cropping</c>. A single value is applied to both ends.
        /// </summary>
        /// <param name="crop_start">Number of elements removed from the start of axis 1.</param>
        /// <param name="crop_end">Number of elements removed from the end of axis 1.</param>
        private void GetCropping ( out int crop_start, out int crop_end ) {
            crop_start = args.cropping[0];
            if ( args.cropping.shape[0] == 1 ) {
                crop_end = crop_start;
            }
            else {
                crop_end = args.cropping[1];
            }
        }
    }
}
|