You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

DepthwiseConv2D.cs 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using System;
  5. using Tensorflow.Keras.ArgsDefinition;
  6. using Tensorflow.Keras.Saving;
  7. using Tensorflow.Common.Types;
  8. using Tensorflow.Keras.Utils;
  9. using Tensorflow.Operations;
  10. using Newtonsoft.Json;
  11. using System.Security.Cryptography;
  12. namespace Tensorflow.Keras.Layers
  13. {
  /// <summary>
  /// Constructor arguments for the <c>DepthwiseConv2D</c> layer: the regular
  /// <c>Conv2DArgs</c> plus the depthwise-specific settings below.
  /// </summary>
  public class DepthwiseConv2DArgs : Conv2DArgs
  {
      /// <summary>
      /// Number of depthwise convolution output channels generated for each
      /// input channel (serialized as <c>depth_multiplier</c>). The total number
      /// of depthwise output channels is <c>filters_in * depth_multiplier</c>.
      /// Defaults to 1.
      /// </summary>
      [JsonProperty(PropertyName = "depth_multiplier")]
      public int DepthMultiplier { get; set; } = 1;

      /// <summary>
      /// Initializer for the depthwise kernel weights
      /// (serialized as <c>depthwise_initializer</c>); may be left unset.
      /// </summary>
      [JsonProperty(PropertyName = "depthwise_initializer")]
      public IInitializer DepthwiseInitializer { get; set; }
  }
  /// <summary>
  /// Depthwise 2D convolution layer. Each input channel is convolved with its own
  /// <c>DepthMultiplier</c> filters (via <c>gen_nn_ops.depthwise_conv2d_native</c>),
  /// giving <c>input_channels * DepthMultiplier</c> output channels, optionally
  /// followed by a bias add and an activation.
  /// </summary>
  public class DepthwiseConv2D : Conv2D
  {
      /// <summary>
      /// depth_multiplier: The number of depthwise convolution output channels for
      /// each input channel. The total number of depthwise convolution output
      /// channels will be equal to `filters_in * depth_multiplier`.
      /// </summary>
      int DepthMultiplier = 1;

      // Optional initializer for the depthwise kernel; build() falls back to
      // kernel_initializer when this is null.
      IInitializer DepthwiseInitializer;

      // Input-rank (length-4 for 2D) strides / dilations in the input's layout, as
      // passed to depthwise_conv2d_native. These hide the base members of the same name.
      int[] strides;
      int[] dilation_rate;

      /// <summary>Maps the Keras data_format to TensorFlow's layout string.</summary>
      string getDataFormat()
      {
          return data_format == "channels_first" ? "NCHW" : "NHWC";
      }

      // Counter for default names of layers created without an explicit name.
      static int _id = 1;

      /// <summary>
      /// Creates the layer from <paramref name="args"/>. Upper-cases the padding
      /// string stored in <paramref name="args"/> and assigns a default name when
      /// none was given.
      /// </summary>
      public DepthwiseConv2D(DepthwiseConv2DArgs args) : base(args)
      {
          // Invariant culture: a culture-sensitive ToUpper() turns "valid" into
          // "VALİD" under a Turkish culture, which is not a valid padding value.
          args.Padding = args.Padding.ToUpperInvariant();
          if (string.IsNullOrEmpty(args.Name))
              name = "DepthwiseConv2D_" + _id++;   // post-increment so each unnamed layer gets a distinct suffix
          this.DepthMultiplier = args.DepthMultiplier;
          this.DepthwiseInitializer = args.DepthwiseInitializer;
      }

      /// <summary>
      /// Creates the depthwise kernel (shape <c>kernel_size + [input_channels, DepthMultiplier]</c>)
      /// and, when <c>use_bias</c> is set, a bias of length <c>input_channels * DepthMultiplier</c>.
      /// Also resolves the native strides/dilations and the input spec.
      /// </summary>
      /// <param name="input_shape">Shape of the layer input; a single shape is expected.</param>
      public override void build(KerasShapesWrapper input_shape)
      {
          // base.build is not called: this layer creates its own (depthwise) weights.
          var shape = input_shape.ToSingleShape();
          int channel_axis = data_format == "channels_first" ? 1 : -1;
          var input_channel = channel_axis < 0 ?
              shape.dims[shape.ndim + channel_axis] :
              shape.dims[channel_axis];

          var arg = args as DepthwiseConv2DArgs;
          this.strides = ToNativeTuple(arg.Strides, shape.ndim, "strides");
          this.dilation_rate = ToNativeTuple(arg.DilationRate, shape.ndim, "dilation_rate");

          // Use the channel dimension resolved above (shape[0] would be the batch
          // dimension for channels_first inputs).
          var depthwise_kernel_shape = this.kernel_size.dims.concat(new long[] {
              input_channel,
              this.DepthMultiplier
          });
          this.kernel = this.add_weight(
              shape: depthwise_kernel_shape,
              initializer: this.DepthwiseInitializer != null ? this.DepthwiseInitializer : this.kernel_initializer,
              name: "depthwise_kernel",
              trainable: true,
              dtype: DType,
              regularizer: this.kernel_regularizer
          );

          var axes = new Dictionary<int, int>();
          axes.Add(channel_axis, (int)input_channel);
          inputSpec = new InputSpec(min_ndim: rank + 2, axes: axes);

          if (use_bias)
          {
              // One bias per output channel: depthwise output has
              // input_channel * DepthMultiplier channels.
              bias = add_weight(name: "bias",
                  shape: (int)(input_channel * this.DepthMultiplier),
                  initializer: bias_initializer,
                  trainable: true,
                  dtype: DType);
          }
          built = true;
          _buildInputShape = input_shape;
      }

      /// <summary>
      /// Expands a stride or dilation setting to the input-rank form used by
      /// depthwise_conv2d_native. A value already at input rank is used as-is.
      /// Otherwise the spatial values (given per spatial dim, or a single value
      /// replicated) are placed at the spatial positions of the data layout, and
      /// the batch and channel positions are 1.
      /// </summary>
      /// <param name="value">Configured strides or dilation rate.</param>
      /// <param name="input_ndim">Rank of the layer input.</param>
      /// <param name="arg_name">Argument name forwarded to conv_utils.normalize_tuple.</param>
      /// <returns>An array of length <paramref name="input_ndim"/>.</returns>
      int[] ToNativeTuple(Shape value, int input_ndim, string arg_name)
      {
          if (value.ndim == input_ndim)
              return value.dims.Select(o => (int)o).ToArray();

          int spatial_ndim = input_ndim - 2;
          int[] spatial = value.ndim == spatial_ndim
              ? value.dims.Select(o => (int)o).ToArray()
              : conv_utils.normalize_tuple(new int[] { (int)value[0] }, spatial_ndim, arg_name);

          var result = new int[input_ndim];
          for (int i = 0; i < result.Length; i++)
              result[i] = 1;
          // Spatial dims start at index 1 for NHWC and at index 2 for NCHW.
          int offset = data_format == "channels_first" ? 2 : 1;
          Array.Copy(spatial, 0, result, offset, spatial_ndim);
          return result;
      }

      /// <summary>
      /// Applies the depthwise convolution, then the bias (channels_last only) and
      /// the activation when configured.
      /// </summary>
      /// <exception cref="NotImplementedException">
      /// When a bias is used with channels_first data.
      /// </exception>
      protected override Tensors Call(Tensors inputs, Tensors state = null,
          bool? training = false, IOptionalArgs? optional_args = null)
      {
          // NOTE(review): this.padding is expected to be the upper-cased value set in
          // the constructor (depends on how the base class exposes padding) — confirm.
          Tensor outputs = gen_nn_ops.depthwise_conv2d_native(
              inputs,
              filter: this.kernel.AsTensor(),
              strides: this.strides,
              padding: this.padding,
              dilations: this.dilation_rate,
              data_format: this.getDataFormat(),
              name: name
          );

          if (use_bias)
          {
              if (data_format == "channels_first")
                  throw new NotImplementedException("call channels_first");

              outputs = gen_nn_ops.bias_add(outputs, ops.convert_to_tensor(bias),
                  data_format: this.getDataFormat(), name: name);
          }

          if (activation != null)
              outputs = activation.Apply(outputs);
          return outputs;
      }
  }
  135. }