using System;
using System.Collections.Generic;
using System.Text;
using System;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Utils;
using Tensorflow.Operations;
using Newtonsoft.Json;
using System.Security.Cryptography;
namespace Tensorflow.Keras.Layers
{
public class DepthwiseConv2DArgs : Conv2DArgs
{
    /// <summary>
    /// depth_multiplier: the number of depthwise convolution output channels
    /// generated for each input channel. The total number of depthwise output
    /// channels equals <c>filters_in * depth_multiplier</c>. Defaults to 1.
    /// </summary>
    [JsonProperty("depth_multiplier")] public int DepthMultiplier { get; set; } = 1;

    /// <summary>
    /// depthwise_initializer: optional initializer for the depthwise kernel.
    /// When left null, <c>DepthwiseConv2D</c> falls back to the kernel initializer.
    /// </summary>
    [JsonProperty("depthwise_initializer")] public IInitializer DepthwiseInitializer { get; set; }
}
/// <summary>
/// Depthwise 2-D convolution layer: every input channel is convolved with its own
/// <c>DepthMultiplier</c> filters, producing <c>in_channels * DepthMultiplier</c>
/// output channels, optionally followed by a bias add and an activation.
/// </summary>
public class DepthwiseConv2D : Conv2D
{
    /// <summary>
    /// depth_multiplier: the number of depthwise convolution output channels for
    /// each input channel. The total number of depthwise convolution output
    /// channels will be equal to <c>filters_in * depth_multiplier</c>.
    /// </summary>
    int DepthMultiplier = 1;

    /// <summary>Optional depthwise kernel initializer; <c>kernel_initializer</c> is used when null.</summary>
    IInitializer DepthwiseInitializer;

    /// <summary>Full-rank strides handed to the native op; filled in by <c>build</c>.</summary>
    int[] _opStrides;

    /// <summary>Full-rank dilations handed to the native op; filled in by <c>build</c>.</summary>
    int[] _opDilations;

    /// <summary>Suffix counter for default layer names; incremented per unnamed instance.</summary>
    static int _id = 1;

    /// <summary>Maps the Keras data format to the TensorFlow layout string ("NCHW" or "NHWC").</summary>
    string getDataFormat()
    {
        return data_format == "channels_first" ? "NCHW" : "NHWC";
    }

    public DepthwiseConv2D(DepthwiseConv2DArgs args) : base(args)
    {
        // The native op expects upper-case "SAME"/"VALID". Invariant casing avoids
        // culture-specific results such as the Turkish dotted capital I ("VALİD").
        args.Padding = args.Padding.ToUpperInvariant();
        if (string.IsNullOrEmpty(args.Name))
        {
            // Post-increment so each unnamed layer gets a distinct suffix.
            name = "DepthwiseConv2D_" + _id++;
        }
        DepthMultiplier = args.DepthMultiplier;
        DepthwiseInitializer = args.DepthwiseInitializer;
    }

    /// <summary>
    /// Creates the depthwise kernel (shape <c>kernel_size + (in_channels, DepthMultiplier)</c>),
    /// the optional bias (length <c>in_channels * DepthMultiplier</c>), and the full-rank
    /// stride/dilation arrays used by <c>Call</c>.
    /// </summary>
    /// <param name="input_shape">Shape of the layer input; channels are read from the data-format channel axis.</param>
    public override void build(KerasShapesWrapper input_shape)
    {
        var shape = input_shape.ToSingleShape();
        int channel_axis = data_format == "channels_first" ? 1 : -1;
        long input_channel = channel_axis < 0
            ? shape.dims[shape.ndim + channel_axis]
            : shape.dims[channel_axis];

        var arg = (DepthwiseConv2DArgs)args;
        _opStrides = ExpandToFullRank(arg.Strides, shape.ndim, "strides");
        _opDilations = ExpandToFullRank(arg.DilationRate, shape.ndim, "dilation_rate");

        // Kernel layout: spatial kernel dims followed by (in_channels, depth_multiplier).
        var depthwise_kernel_shape = kernel_size.dims.concat(new long[] {
            input_channel,
            DepthMultiplier
        });
        kernel = add_weight(
            shape: depthwise_kernel_shape,
            initializer: DepthwiseInitializer ?? kernel_initializer,
            name: "depthwise_kernel",
            trainable: true,
            dtype: DType,
            regularizer: kernel_regularizer
        );

        // Pin the channel count on the actual channel axis for this data format.
        var axes = new Dictionary<int, int> { [channel_axis] = (int)input_channel };
        inputSpec = new InputSpec(min_ndim: rank + 2, axes: axes);

        if (use_bias)
        {
            // One bias per output channel: each input channel yields DepthMultiplier outputs.
            bias = add_weight(name: "bias",
                shape: (int)(input_channel * DepthMultiplier),
                initializer: bias_initializer,
                trainable: true,
                dtype: DType);
        }

        built = true;
        _buildInputShape = input_shape;
    }

    /// <summary>
    /// Expands a spatial convolution parameter (strides or dilation rate) to the
    /// full-rank array expected by depthwise_conv2d_native: spatial values are placed
    /// at the spatial positions for the current data format, and 1 is used for the
    /// batch and channel positions.
    /// </summary>
    /// <param name="value">A full-rank value (returned as-is), a single value applied to
    /// every spatial dimension, or exactly one value per spatial dimension.</param>
    /// <param name="ndim">Rank of the layer input.</param>
    /// <param name="argName">Argument name reported in error messages.</param>
    /// <exception cref="ArgumentException">The length of <paramref name="value"/> matches none of the accepted forms.</exception>
    private int[] ExpandToFullRank(Shape value, int ndim, string argName)
    {
        if (value.ndim == ndim)
            return value.dims.Select(d => (int)d).ToArray();

        int spatialRank = ndim - 2;
        if (spatialRank < 1 || (value.ndim != 1 && value.ndim != spatialRank))
        {
            throw new ArgumentException(
                $"`{argName}` must have 1, {spatialRank} or {ndim} elements for a rank-{ndim} input, got {value.ndim}.",
                argName);
        }

        var full = new int[ndim];
        for (int i = 0; i < ndim; i++)
            full[i] = 1;

        // NHWC keeps spatial dims at positions 1..ndim-2; NCHW at 2..ndim-1.
        int offset = data_format == "channels_first" ? 2 : 1;
        for (int i = 0; i < spatialRank; i++)
            full[offset + i] = (int)(value.ndim == 1 ? value[0] : value[i]);
        return full;
    }

    /// <summary>
    /// Runs the depthwise convolution, then the optional bias add and activation.
    /// </summary>
    /// <exception cref="NotImplementedException">Bias is enabled with <c>channels_first</c> data format.</exception>
    protected override Tensors Call(Tensors inputs, Tensors state = null,
        bool? training = false, IOptionalArgs? optional_args = null)
    {
        Tensor outputs = gen_nn_ops.depthwise_conv2d_native(
            inputs,
            filter: kernel.AsTensor(),
            strides: _opStrides,
            padding: padding,
            dilations: _opDilations,
            data_format: getDataFormat(),
            name: name
        );

        if (use_bias)
        {
            if (data_format == "channels_first")
                throw new NotImplementedException("call channels_first");

            outputs = gen_nn_ops.bias_add(outputs, ops.convert_to_tensor(bias),
                data_format: getDataFormat(), name: name);
        }

        if (activation != null)
            outputs = activation.Apply(outputs);
        return outputs;
    }
}
}