You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

Cropping1D.cs 2.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. using Tensorflow.Keras.ArgsDefinition;
  2. using Tensorflow.Keras.Engine;
  3. namespace Tensorflow.Keras.Layers {
  4. public class Cropping1D : Layer {
  5. CroppingArgs args;
  6. public Cropping1D ( CroppingArgs args ) : base(args) {
  7. this.args = args;
  8. }
  9. protected override void build ( Tensors inputs ) {
  10. if ( args.cropping.rank != 1 ) {
  11. // throw an ValueError exception
  12. throw new ValueError("");
  13. }
  14. else if ( args.cropping.shape[0] > 2 || args.cropping.shape[0] < 1 ) {
  15. throw new ValueError("The `cropping` argument must be a tuple of 2 integers.");
  16. }
  17. built = true;
  18. }
  19. protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
  20. Tensor output = inputs;
  21. if ( output.rank != 3 ) {
  22. // throw an ValueError exception
  23. throw new ValueError("Expected dim=3, found dim=" + output.rank);
  24. }
  25. if ( args.cropping.shape[0] == 1 ) {
  26. int crop_start = args.cropping[0];
  27. output = output[new Slice(), new Slice(crop_start, ( int ) output.shape[1] - crop_start), new Slice()];
  28. }
  29. else {
  30. int crop_start = args.cropping[0], crop_end = args.cropping[1];
  31. output = output[new Slice(), new Slice(crop_start, ( int ) (output.shape[1]) - crop_end), new Slice()];
  32. }
  33. return output;
  34. }
  35. public override Shape ComputeOutputShape ( Shape input_shape ) {
  36. if ( args.cropping.shape[0] == 1 ) {
  37. int crop = args.cropping[0];
  38. return new Shape(( int ) (input_shape[0]), ( int ) (input_shape[1] - crop * 2), ( int ) (input_shape[2]));
  39. }
  40. else {
  41. int crop_start = args.cropping[0], crop_end = args.cropping[1];
  42. return new Shape(( int ) (input_shape[0]), ( int ) (input_shape[1] - crop_start - crop_end), ( int ) (input_shape[2]));
  43. }
  44. }
  45. }
  46. }