You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

Cropping3D.cs 7.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. using Tensorflow.Keras.ArgsDefinition.Reshaping;
  2. using Tensorflow.Keras.Engine;
  3. using Tensorflow.Keras.Saving;
  4. using Tensorflow.Common.Types;
  namespace Tensorflow.Keras.Layers.Reshaping
  {
      /// <summary>
      /// Cropping layer for rank-5 inputs; the 3D counterpart of Cropping2D.
      /// Crops the three spatial axes (axes 1-3 for channels_last, axes 2-4 otherwise)
      /// and leaves the batch and channel axes untouched.
      /// </summary>
      /// <remarks>
      /// <c>args.cropping</c> may have shape (1): the same amount is removed from both ends of
      /// every spatial axis. It may have shape (3): one symmetric amount per spatial axis.
      /// It may have shape (3, 2): explicit (start, end) amounts per spatial axis.
      /// Any other shape is rejected with <see cref="ValueError"/>.
      /// </remarks>
      public class Cropping3D : Layer
      {
          Cropping3DArgs args;

          public Cropping3D(Cropping3DArgs args) : base(args)
          {
              this.args = args;
          }

          public override void build(KerasShapesWrapper input_shape)
          {
              built = true;
              _buildInputShape = input_shape;
          }

          /// <summary>Crops the spatial axes of a rank-5 input tensor.</summary>
          /// <exception cref="ValueError">
          /// The input is not rank 5, or <c>args.cropping</c> has an unsupported shape.
          /// </exception>
          protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
          {
              Tensor output = inputs;
              if (output.rank != 5)
              {
                  throw new ValueError("Expected dim=5, found dim=" + output.rank);
              }

              var crops = GetCropAmounts();
              int firstSpatialAxis = GetFirstSpatialAxis();

              // Start with open slices (keep everything) so batch/channel axes are untouched,
              // then narrow each spatial axis to [start, size - end).
              var slices = new Slice[5];
              for (int axis = 0; axis < slices.Length; axis++)
              {
                  slices[axis] = new Slice();
              }
              for (int i = 0; i < crops.Length; i++)
              {
                  int axis = firstSpatialAxis + i;
                  slices[axis] = new Slice(crops[i].Start, (int)output.shape[axis] - crops[i].End);
              }

              return output[slices];
          }

          /// <summary>
          /// Computes the output shape: each spatial dimension shrinks by its start + end crop;
          /// batch and channel dimensions are copied through.
          /// </summary>
          /// <exception cref="ValueError"><c>args.cropping</c> has an unsupported shape.</exception>
          public override Shape ComputeOutputShape(Shape input_shape)
          {
              var crops = GetCropAmounts();
              int firstSpatialAxis = GetFirstSpatialAxis();

              var dims = new int[5];
              for (int axis = 0; axis < dims.Length; axis++)
              {
                  dims[axis] = (int)input_shape[axis];
              }
              for (int i = 0; i < crops.Length; i++)
              {
                  int axis = firstSpatialAxis + i;
                  // Negative sizes denote an unknown dimension (TF.NET uses -1); keep them
                  // unknown rather than subtracting the crop from the placeholder value.
                  if (dims[axis] >= 0)
                  {
                      dims[axis] -= crops[i].Start + crops[i].End;
                  }
              }

              return new Shape(dims);
          }

          /// <summary>
          /// Normalizes <c>args.cropping</c> into one (Start, End) pair per spatial axis, in axis order.
          /// Shared by <see cref="Call"/> and <see cref="ComputeOutputShape"/> so both accept
          /// exactly the same cropping specifications.
          /// </summary>
          /// <exception cref="ValueError"><c>args.cropping</c> is not of shape (1), (3) or (3, 2).</exception>
          private (int Start, int End)[] GetCropAmounts()
          {
              var cropping = args.cropping;
              var shape = cropping.shape;

              if (shape == new Shape(1))
              {
                  int crop = cropping[0];
                  return new[] { (crop, crop), (crop, crop), (crop, crop) };
              }

              // int[3]: one symmetric crop per spatial axis.
              if (shape == new Shape(3))
              {
                  int crop1 = cropping[0];
                  int crop2 = cropping[1];
                  int crop3 = cropping[2];
                  return new[] { (crop1, crop1), (crop2, crop2), (crop3, crop3) };
              }

              // int[3, 2]: explicit (start, end) per spatial axis.
              if (shape == new Shape(3, 2))
              {
                  int x = cropping[0, 0], x_end = cropping[0, 1];
                  int y = cropping[1, 0], y_end = cropping[1, 1];
                  int z = cropping[2, 0], z_end = cropping[2, 1];
                  return new[] { (x, x_end), (y, y_end), (z, z_end) };
              }

              throw new ValueError($"`cropping` must have shape (1), (3) or (3, 2), but got {shape}");
          }

          /// <summary>
          /// Index of the first spatial axis: 1 for channels_last, 2 otherwise.
          /// </summary>
          private int GetFirstSpatialAxis()
              => args.data_format == Cropping3DArgs.DataFormat.channels_last ? 1 : 2;
      }
  }