| @@ -0,0 +1,31 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Bayesian Layer. | |||
| The high-level components (Cells) used to construct the Bayesian neural network. | |||
| """ | |||
| from . import conv_variational, dense_variational, layer_distribution, bnn_cell_wrapper | |||
| from .conv_variational import ConvReparam | |||
| from .dense_variational import DenseReparam | |||
| from .layer_distribution import NormalPrior, NormalPosterior | |||
| from .bnn_cell_wrapper import WithBNNLossCell | |||
| __all__ = [] | |||
| __all__.extend(conv_variational.__all__) | |||
| __all__.extend(dense_variational.__all__) | |||
| __all__.extend(layer_distribution.__all__) | |||
| __all__.extend(bnn_cell_wrapper.__all__) | |||
| @@ -0,0 +1,92 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Generate WithLossCell suitable for BNN.""" | |||
| from .conv_variational import _ConvVariational | |||
| from .dense_variational import _DenseVariational | |||
| from ..transforms.bnn_loss.generate_kl_loss import gain_bnn_with_loss | |||
| __all__ = ['WithBNNLossCell'] | |||
| class ClassWrap: | |||
| """Decorator of WithBNNLossCell""" | |||
| def __init__(self, cls): | |||
| self._cls = cls | |||
| self.bnn_loss_file = None | |||
| def __call__(self, backbone, loss_fn, backbone_factor, kl_factor): | |||
| obj = self._cls(backbone, loss_fn, backbone_factor, kl_factor) | |||
| bnn_with_loss = obj() | |||
| self.bnn_loss_file = obj.bnn_loss_file | |||
| return bnn_with_loss | |||
| @ClassWrap | |||
| class WithBNNLossCell: | |||
| r""" | |||
| Generate WithLossCell suitable for BNN. | |||
| Args: | |||
| backbone (Cell): The target network. | |||
| loss_fn (Cell): The loss function used to compute loss. | |||
| dnn_factor (int, float): The coefficient of the backbone's loss, which is computed by the loss function. Default: 1. | |||
| bnn_factor (int, float): The coefficient of the KL loss, which is the KL divergence of the Bayesian layers. Default: 1. | |||
| Inputs: | |||
| - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`. | |||
| - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`. | |||
| Outputs: | |||
| Tensor, a scalar tensor with shape :math:`()`. | |||
| Examples: | |||
| >>> net = Net() | |||
| >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||
| >>> net_with_criterion = WithBNNLossCell(net, loss_fn, 1, 1) | |||
| >>> | |||
| >>> batch_size = 2 | |||
| >>> data = Tensor(np.ones([batch_size, 3, 64, 64]).astype(np.float32) * 0.01) | |||
| >>> label = Tensor(np.ones([batch_size, 1, 1, 1]).astype(np.int32)) | |||
| >>> | |||
| >>> net_with_criterion(data, label) | |||
| """ | |||
| def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1): | |||
| self.backbone = backbone | |||
| self.loss_fn = loss_fn | |||
| self.dnn_factor = dnn_factor | |||
| self.bnn_factor = bnn_factor | |||
| self.bnn_loss_file = None | |||
| def _generate_loss_cell(self): | |||
| """Generate WithBNNLossCell by ast.""" | |||
| layer_count = self._kl_loss_count(self.backbone) | |||
| bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn, | |||
| self.dnn_factor, self.bnn_factor) | |||
| return bnn_with_loss | |||
| def _kl_loss_count(self, net): | |||
| """ Calculate the number of Bayesian layers.""" | |||
| count = 0 | |||
| for (_, layer) in net.name_cells().items(): | |||
| if isinstance(layer, (_DenseVariational, _ConvVariational)): | |||
| count += 1 | |||
| else: | |||
| count += self._kl_loss_count(layer) | |||
| return count | |||
| def __call__(self): | |||
| return self._generate_loss_cell() | |||
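| # Illustration only (not part of this patch): a minimal numpy sketch of how the | |||
| # generated loss cell is assumed to combine the two terms: the backbone loss is | |||
| # scaled by dnn_factor and the summed KL losses of the Bayesian layers are scaled | |||
| # by bnn_factor. The names `task_loss` and `kl_losses` are hypothetical. | |||
| import numpy as np | |||
| def combined_bnn_loss(task_loss, kl_losses, dnn_factor=1.0, bnn_factor=1.0): | |||
|     """Weighted sum of the task loss and the per-layer KL losses.""" | |||
|     return dnn_factor * task_loss + bnn_factor * np.sum(kl_losses) | |||
| # Example: a cross-entropy of 0.7 plus KL terms from three Bayesian layers. | |||
| print(combined_bnn_loss(0.7, [12.3, 8.1, 4.4]))  # 25.5 | |||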
| @@ -0,0 +1,270 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Convolutional variational layers.""" | |||
| from mindspore.ops import operations as P | |||
| from mindspore._checkparam import twice | |||
| from ...layer.conv import _Conv | |||
| from ...cell import Cell | |||
| from .layer_distribution import NormalPrior, NormalPosterior | |||
| __all__ = ['ConvReparam'] | |||
| class _ConvVariational(_Conv): | |||
| """ | |||
| Base class for all convolutional variational layers. | |||
| """ | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride=1, | |||
| pad_mode='same', | |||
| padding=0, | |||
| dilation=1, | |||
| group=1, | |||
| has_bias=False, | |||
| weight_prior_fn=NormalPrior, | |||
| weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape), | |||
| bias_prior_fn=NormalPrior, | |||
| bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)): | |||
| kernel_size = twice(kernel_size) | |||
| stride = twice(stride) | |||
| dilation = twice(dilation) | |||
| super(_ConvVariational, self).__init__( | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride, | |||
| pad_mode, | |||
| padding, | |||
| dilation, | |||
| group, | |||
| has_bias, | |||
| weight_init='normal', | |||
| bias_init='zeros' | |||
| ) | |||
| if pad_mode not in ('valid', 'same', 'pad'): | |||
| raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed ' | |||
| + str(pad_mode) + ', should be one of \'valid\', \'same\', \'pad\'.') | |||
| # convolution args | |||
| self.in_channels = in_channels | |||
| self.out_channels = out_channels | |||
| self.kernel_size = kernel_size | |||
| self.stride = stride | |||
| self.pad_mode = pad_mode | |||
| self.padding = padding | |||
| self.dilation = dilation | |||
| self.group = group | |||
| self.has_bias = has_bias | |||
| # distribution trainable parameters | |||
| self.shape = [self.out_channels, | |||
| self.in_channels // self.group, *self.kernel_size] | |||
| self.weight.requires_grad = False | |||
| if isinstance(weight_prior_fn, Cell): | |||
| self.weight_prior = weight_prior_fn | |||
| else: | |||
| self.weight_prior = weight_prior_fn() | |||
| self.weight_posterior = weight_posterior_fn(shape=self.shape, name='bnn_weight') | |||
| if self.has_bias: | |||
| self.bias.requires_grad = False | |||
| if isinstance(bias_prior_fn, Cell): | |||
| self.bias_prior = bias_prior_fn | |||
| else: | |||
| self.bias_prior = bias_prior_fn() | |||
| self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias') | |||
| # mindspore operations | |||
| self.bias_add = P.BiasAdd() | |||
| self.conv2d = P.Conv2D(out_channel=self.out_channels, | |||
| kernel_size=self.kernel_size, | |||
| mode=1, | |||
| pad_mode=self.pad_mode, | |||
| pad=self.padding, | |||
| stride=self.stride, | |||
| dilation=self.dilation, | |||
| group=self.group) | |||
| self.log = P.Log() | |||
| self.sum = P.ReduceSum() | |||
| def construct(self, inputs): | |||
| outputs = self._apply_variational_weight(inputs) | |||
| if self.has_bias: | |||
| outputs = self._apply_variational_bias(outputs) | |||
| return outputs | |||
| def extend_repr(self): | |||
| str_info = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, pad_mode={}, ' \ | |||
| 'padding={}, dilation={}, group={}, weight_mean={}, weight_std={}, has_bias={}'\ | |||
| .format(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.pad_mode, self.padding, | |||
| self.dilation, self.group, self.weight_posterior.mean, self.weight_posterior.untransformed_std, | |||
| self.has_bias) | |||
| if self.has_bias: | |||
| str_info = str_info + ', bias_mean={}, bias_std={}'\ | |||
| .format(self.bias_posterior.mean, self.bias_posterior.untransformed_std) | |||
| return str_info | |||
| def _apply_variational_bias(self, inputs): | |||
| bias_posterior_tensor = self.bias_posterior("sample") | |||
| return self.bias_add(inputs, bias_posterior_tensor) | |||
| def compute_kl_loss(self): | |||
| """Compute kl loss""" | |||
| weight_post_mean = self.weight_posterior("mean") | |||
| weight_post_sd = self.weight_posterior("sd") | |||
| kl = self.weight_prior("kl_loss", "Normal", | |||
| weight_post_mean, weight_post_sd) | |||
| kl_loss = self.sum(kl) | |||
| if self.has_bias: | |||
| bias_post_mean = self.bias_posterior("mean") | |||
| bias_post_sd = self.bias_posterior("sd") | |||
| kl = self.bias_prior("kl_loss", "Normal", | |||
| bias_post_mean, bias_post_sd) | |||
| kl = self.sum(kl) | |||
| kl_loss += kl | |||
| return kl_loss | |||
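| # Illustration only (not part of this patch): compute_kl_loss above reduces an | |||
| # elementwise KL divergence between the posterior N(mu_q, sigma_q) and the prior | |||
| # N(mu_p, sigma_p). A minimal numpy sketch of that closed form, assuming the | |||
| # prior's "kl_loss" call evaluates KL(posterior || prior): | |||
| import numpy as np | |||
| def kl_normal_normal(mu_q, sigma_q, mu_p, sigma_p): | |||
|     """Elementwise KL(N(mu_q, sigma_q) || N(mu_p, sigma_p)).""" | |||
|     return (np.log(sigma_p / sigma_q) | |||
|             + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sigma_p ** 2) | |||
|             - 0.5) | |||
| posterior_mean = np.zeros((3, 3)) | |||
| posterior_std = np.full((3, 3), 0.05) | |||
| # A prior with mean 0 and standard deviation 0.1, matching NormalPrior's defaults. | |||
| kl = kl_normal_normal(posterior_mean, posterior_std, 0.0, 0.1) | |||
| print(kl.sum())  # the scalar KL contribution of one weight tensor | |||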
| class ConvReparam(_ConvVariational): | |||
| r""" | |||
| Convolutional variational layers with Reparameterization. | |||
| See more details in the paper `Auto-Encoding Variational Bayes | |||
| <https://arxiv.org/abs/1312.6114>`_. | |||
| Args: | |||
| in_channels (int): The number of input channel :math:`C_{in}`. | |||
| out_channels (int): The number of output channel :math:`C_{out}`. | |||
| kernel_size (Union[int, tuple[int]]): The data type is int or | |||
| tuple with 2 integers. Specifies the height and width of the 2D | |||
| convolution window. A single int means the value is for both | |||
| height and width of the kernel. A tuple of 2 ints means the | |||
| first value is for the height and the other is for the width of | |||
| the kernel. | |||
| stride (Union[int, tuple[int]]): The distance of kernel moving, | |||
| an int number that represents the height and width of movement | |||
| are both strides, or a tuple of two int numbers that represent | |||
| height and width of movement respectively. Default: 1. | |||
| pad_mode (str): Specifies padding mode. The optional values are | |||
| "same", "valid", "pad". Default: "same". | |||
| - same: Adopts the way of completion. Output height and width | |||
| will be the same as the input. | |||
| Total number of padding will be calculated for horizontal and | |||
| vertical direction and evenly distributed to top and bottom, | |||
| left and right if possible. Otherwise, the last extra padding | |||
| will be done from the bottom and the right side. If this mode | |||
| is set, `padding` must be 0. | |||
| - valid: Adopts the way of discarding. The possibly largest | |||
| height and width of the output will be returned without padding. | |||
| Extra pixels will be discarded. If this mode is set, `padding` | |||
| must be 0. | |||
| - pad: Implicit paddings on both sides of the input. The number | |||
| of `padding` will be padded to the input Tensor borders. | |||
| `padding` should be greater than or equal to 0. | |||
| padding (Union[int, tuple[int]]): Implicit paddings on both sides of | |||
| the input. Default: 0. | |||
| dilation (Union[int, tuple[int]]): The data type is int or tuple | |||
| with 2 integers. Specifies the dilation rate to use for dilated | |||
| convolution. If set to be :math:`k > 1`, | |||
| there will be :math:`k - 1` pixels skipped for each sampling | |||
| location. Its value should be greater or equal to 1 and bounded | |||
| by the height and width of the input. Default: 1. | |||
| group (int): Splits filter into groups, `in_channels` and | |||
| `out_channels` should be divisible by the number of groups. | |||
| Default: 1. | |||
| has_bias (bool): Specifies whether the layer uses a bias vector. | |||
| Default: False. | |||
| weight_prior_fn: The prior distribution of the convolution kernel. | |||
| It should return a mindspore distribution instance. | |||
| Default: NormalPrior (which creates an instance of a normal | |||
| distribution with mean 0 and standard deviation 0.1). | |||
| weight_posterior_fn: posterior distribution for sampling convolution | |||
| kernel. It should be a function handle which returns a mindspore | |||
| distribution instance. | |||
| Default: NormalPosterior. | |||
| bias_prior_fn: The prior distribution of the bias vector. It should return | |||
| a mindspore distribution instance. | |||
| Default: NormalPrior (which creates an instance of a normal | |||
| distribution with mean 0 and standard deviation 0.1). | |||
| bias_posterior_fn: posterior distribution for sampling bias vector. | |||
| It should be a function handle which returns a mindspore | |||
| distribution instance. | |||
| Default: NormalPosterior. | |||
| Inputs: | |||
| - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. | |||
| Outputs: | |||
| Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. | |||
| Examples: | |||
| >>> net = ConvReparam(120, 240, 4, has_bias=False) | |||
| >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) | |||
| >>> net(input).shape | |||
| (1, 240, 1024, 640) | |||
| """ | |||
| def __init__( | |||
| self, | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride=1, | |||
| pad_mode='same', | |||
| padding=0, | |||
| dilation=1, | |||
| group=1, | |||
| has_bias=False, | |||
| weight_prior_fn=NormalPrior, | |||
| weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape), | |||
| bias_prior_fn=NormalPrior, | |||
| bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)): | |||
| super(ConvReparam, self).__init__( | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride=stride, | |||
| pad_mode=pad_mode, | |||
| padding=padding, | |||
| dilation=dilation, | |||
| group=group, | |||
| has_bias=has_bias, | |||
| weight_prior_fn=weight_prior_fn, | |||
| weight_posterior_fn=weight_posterior_fn, | |||
| bias_prior_fn=bias_prior_fn, | |||
| bias_posterior_fn=bias_posterior_fn | |||
| ) | |||
| def _apply_variational_weight(self, inputs): | |||
| weight_posterior_tensor = self.weight_posterior("sample") | |||
| outputs = self.conv2d(inputs, weight_posterior_tensor) | |||
| return outputs | |||
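| # Illustration only (not part of this patch): a numpy sketch of the | |||
| # reparameterization trick that the posterior's "sample" call is assumed to use | |||
| # for the convolution kernel, so gradients can flow through the learned mean and | |||
| # untransformed standard deviation. | |||
| import numpy as np | |||
| def sample_weight(mean, untransformed_std): | |||
|     """w = mean + softplus(untransformed_std) * eps, with eps ~ N(0, 1).""" | |||
|     std = np.log1p(np.exp(untransformed_std)) + 1e-6  # same mapping as std_trans | |||
|     eps = np.random.standard_normal(mean.shape) | |||
|     return mean + std * eps | |||
| # Kernel shape [out_channels, in_channels // group, kh, kw], as built above. | |||
| mean = np.zeros((240, 120, 4, 4)) | |||
| rho = np.full((240, 120, 4, 4), -5.0)  # softplus(-5) is roughly 0.0067, a small initial std | |||
| print(sample_weight(mean, rho).shape)  # (240, 120, 4, 4) | |||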
| @@ -0,0 +1,188 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """dense_variational""" | |||
| from mindspore.ops import operations as P | |||
| from mindspore._checkparam import check_int_positive, check_bool | |||
| from ...cell import Cell | |||
| from ...layer.activation import get_activation | |||
| from .layer_distribution import NormalPrior, NormalPosterior | |||
| __all__ = ['DenseReparam'] | |||
| class _DenseVariational(Cell): | |||
| """ | |||
| Base class for all dense variational layers. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| in_channels, | |||
| out_channels, | |||
| activation=None, | |||
| has_bias=True, | |||
| weight_prior_fn=NormalPrior, | |||
| weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape), | |||
| bias_prior_fn=NormalPrior, | |||
| bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)): | |||
| super(_DenseVariational, self).__init__() | |||
| self.in_channels = check_int_positive(in_channels) | |||
| self.out_channels = check_int_positive(out_channels) | |||
| self.has_bias = check_bool(has_bias) | |||
| if isinstance(weight_prior_fn, Cell): | |||
| self.weight_prior = weight_prior_fn | |||
| else: | |||
| self.weight_prior = weight_prior_fn() | |||
| self.weight_posterior = weight_posterior_fn(shape=[self.out_channels, self.in_channels], name='bnn_weight') | |||
| if self.has_bias: | |||
| if isinstance(bias_prior_fn, Cell): | |||
| self.bias_prior = bias_prior_fn | |||
| else: | |||
| self.bias_prior = bias_prior_fn() | |||
| self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias') | |||
| self.activation = activation | |||
| if isinstance(self.activation, str): | |||
| self.activation = get_activation(activation) | |||
| self.activation_flag = self.activation is not None | |||
| self.matmul = P.MatMul(transpose_b=True) | |||
| self.bias_add = P.BiasAdd() | |||
| self.sum = P.ReduceSum() | |||
| def construct(self, x): | |||
| outputs = self._apply_variational_weight(x) | |||
| if self.has_bias: | |||
| outputs = self._apply_variational_bias(outputs) | |||
| if self.activation_flag: | |||
| outputs = self.activation(outputs) | |||
| return outputs | |||
| def extend_repr(self): | |||
| str_info = 'in_channels={}, out_channels={}, weight_mean={}, weight_std={}, has_bias={}' \ | |||
| .format(self.in_channels, self.out_channels, self.weight_posterior.mean, | |||
| self.weight_posterior.untransformed_std, self.has_bias) | |||
| if self.has_bias: | |||
| str_info = str_info + ', bias_mean={}, bias_std={}' \ | |||
| .format(self.bias_posterior.mean, self.bias_posterior.untransformed_std) | |||
| if self.activation_flag: | |||
| str_info = str_info + ', activation={}'.format(self.activation) | |||
| return str_info | |||
| def _apply_variational_bias(self, inputs): | |||
| bias_posterior_tensor = self.bias_posterior("sample") | |||
| return self.bias_add(inputs, bias_posterior_tensor) | |||
| def compute_kl_loss(self): | |||
| """Compute kl loss.""" | |||
| weight_post_mean = self.weight_posterior("mean") | |||
| weight_post_sd = self.weight_posterior("sd") | |||
| kl = self.weight_prior("kl_loss", "Normal", weight_post_mean, weight_post_sd) | |||
| kl_loss = self.sum(kl) | |||
| if self.has_bias: | |||
| bias_post_mean = self.bias_posterior("mean") | |||
| bias_post_sd = self.bias_posterior("sd") | |||
| kl = self.bias_prior("kl_loss", "Normal", bias_post_mean, bias_post_sd) | |||
| kl = self.sum(kl) | |||
| kl_loss += kl | |||
| return kl_loss | |||
| class DenseReparam(_DenseVariational): | |||
| r""" | |||
| Dense variational layers with Reparameterization. | |||
| See more details in the paper `Auto-Encoding Variational Bayes | |||
| <https://arxiv.org/abs/1312.6114>`_. | |||
| Applies a dense connected layer to the input. This layer implements the operation as: | |||
| .. math:: | |||
| \text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}), | |||
| where :math:`\text{activation}` is the activation function passed as the activation | |||
| argument (if passed in), :math:`\text{kernel}` is a weight matrix with the same | |||
| data type as the inputs, sampled from the posterior distribution of the weight, | |||
| and :math:`\text{bias}` is a bias vector with the same data type as the inputs, | |||
| sampled from the posterior distribution of the bias (only if has_bias is True). | |||
| Args: | |||
| in_channels (int): The number of input channels. | |||
| out_channels (int): The number of output channels. | |||
| has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. | |||
| activation (str): Activation function applied to the output of the layer, e.g. 'relu'. Default: None. | |||
| weight_prior_fn: The prior distribution of the weight. | |||
| It should return a mindspore distribution instance. | |||
| Default: NormalPrior (which creates an instance of a normal | |||
| distribution with mean 0 and standard deviation 0.1). | |||
| weight_posterior_fn: posterior distribution for sampling weight. | |||
| It should be a function handle which returns a mindspore | |||
| distribution instance. | |||
| Default: NormalPosterior. | |||
| bias_prior_fn: The prior distribution of the bias vector. It should return | |||
| a mindspore distribution instance. | |||
| Default: NormalPrior (which creates an instance of a normal | |||
| distribution with mean 0 and standard deviation 0.1). | |||
| bias_posterior_fn: posterior distribution for sampling bias vector. | |||
| It should be a function handle which returns a mindspore | |||
| distribution instance. | |||
| Default: NormalPosterior. | |||
| Inputs: | |||
| - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`. | |||
| Outputs: | |||
| Tensor of shape :math:`(N, out\_channels)`. | |||
| Examples: | |||
| >>> net = DenseReparam(3, 4) | |||
| >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) | |||
| >>> net(input) | |||
| """ | |||
| def __init__( | |||
| self, | |||
| in_channels, | |||
| out_channels, | |||
| activation=None, | |||
| has_bias=True, | |||
| weight_prior_fn=NormalPrior, | |||
| weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape), | |||
| bias_prior_fn=NormalPrior, | |||
| bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)): | |||
| super(DenseReparam, self).__init__( | |||
| in_channels, | |||
| out_channels, | |||
| activation=activation, | |||
| has_bias=has_bias, | |||
| weight_prior_fn=weight_prior_fn, | |||
| weight_posterior_fn=weight_posterior_fn, | |||
| bias_prior_fn=bias_prior_fn, | |||
| bias_posterior_fn=bias_posterior_fn | |||
| ) | |||
| def _apply_variational_weight(self, inputs): | |||
| weight_posterior_tensor = self.weight_posterior("sample") | |||
| outputs = self.matmul(inputs, weight_posterior_tensor) | |||
| return outputs | |||
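| # Illustration only (not part of this patch): a numpy sketch of the forward | |||
| # computation DenseReparam is assumed to perform, with a freshly sampled weight | |||
| # and bias on every call (matmul uses transpose_b=True above, so the weight has | |||
| # shape [out_channels, in_channels]). | |||
| import numpy as np | |||
| def softplus(x): | |||
|     return np.log1p(np.exp(x)) | |||
| def dense_reparam_forward(x, w_mean, w_rho, b_mean, b_rho): | |||
|     w = w_mean + softplus(w_rho) * np.random.standard_normal(w_mean.shape) | |||
|     b = b_mean + softplus(b_rho) * np.random.standard_normal(b_mean.shape) | |||
|     return x @ w.T + b | |||
| x = np.random.rand(2, 3)  # batch of 2, in_channels = 3 | |||
| w_mean, w_rho = np.zeros((4, 3)), np.full((4, 3), -5.0) | |||
| b_mean, b_rho = np.zeros(4), np.full(4, -5.0) | |||
| print(dense_reparam_forward(x, w_mean, w_rho, b_mean, b_rho).shape)  # (2, 4) | |||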
| @@ -0,0 +1,96 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Initialize normal distributions""" | |||
| import numpy as np | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.ops import operations as P | |||
| from ...cell import Cell | |||
| from ..distribution.normal import Normal | |||
| __all__ = ['NormalPrior', 'NormalPosterior'] | |||
| class NormalPrior(Cell): | |||
| r""" | |||
| Initialize a normal prior distribution with a mean of 0 and a standard deviation of 0.1 by default. | |||
| Args: | |||
| dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor. | |||
| Default: mindspore.float32. | |||
| mean (int, float): Mean of the normal distribution. Default: 0. | |||
| std (int, float): Standard deviation of the normal distribution. Default: 0.1. | |||
| Returns: | |||
| Cell, a normal distribution. | |||
| """ | |||
| def __init__(self, dtype=mstype.float32, mean=0, std=0.1): | |||
| super(NormalPrior, self).__init__() | |||
| self.normal = Normal(mean, std, dtype=dtype) | |||
| def construct(self, *inputs): | |||
| return self.normal(*inputs) | |||
| class NormalPosterior(Cell): | |||
| r""" | |||
| Build Normal distributions with trainable parameters. | |||
| Args: | |||
| name (str): Name prepended to the names of the trainable parameters. | |||
| shape (list): Shape of the mean and standard deviation. | |||
| dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor. | |||
| Default: mindspore.float32. | |||
| loc_mean (float, array_like of floats): Mean of the normal distribution used to initialize the trainable | |||
| mean parameter. Default: 0. | |||
| loc_std (float, array_like of floats): Standard deviation of the normal distribution used to initialize | |||
| the trainable mean parameter. Default: 0.1. | |||
| untransformed_scale_mean (float, array_like of floats): Mean of the normal distribution used to initialize | |||
| the trainable untransformed scale parameter. Default: -5. | |||
| untransformed_scale_std (float, array_like of floats): Standard deviation of the normal distribution used | |||
| to initialize the trainable untransformed scale parameter. Default: 0.1. | |||
| Returns: | |||
| Cell, a normal distribution. | |||
| """ | |||
| def __init__(self, | |||
| name, | |||
| shape, | |||
| dtype=mstype.float32, | |||
| loc_mean=0, | |||
| loc_std=0.1, | |||
| untransformed_scale_mean=-5, | |||
| untransformed_scale_std=0.1): | |||
| super(NormalPosterior, self).__init__() | |||
| if not isinstance(name, str): | |||
| raise ValueError('The type of `name` should be `str`') | |||
| self.mean = Parameter( | |||
| Tensor(np.random.normal(loc_mean, loc_std, shape), dtype=dtype), name=name + '_mean') | |||
| self.untransformed_std = Parameter( | |||
| Tensor(np.random.normal(untransformed_scale_mean, untransformed_scale_std, shape), dtype=dtype), | |||
| name=name + '_untransformed_std') | |||
| self.normal = Normal() | |||
| def std_trans(self, std_pre): | |||
| """Transform std_pre to prevent its value being zero.""" | |||
| std = 1e-6 + P.Log()(P.Exp()(std_pre) + 1) | |||
| return std | |||
| def construct(self, *inputs): | |||
| std = self.std_trans(self.untransformed_std) | |||
| return self.normal(*inputs, mean=self.mean, sd=std) | |||
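| # Illustration only (not part of this patch): std_trans above maps the | |||
| # unconstrained parameter to a strictly positive standard deviation via a | |||
| # softplus, std = log(exp(x) + 1) + 1e-6. A small numpy check of that mapping: | |||
| import numpy as np | |||
| def std_trans(x): | |||
|     return np.log1p(np.exp(x)) + 1e-6 | |||
| for x in (-5.0, 0.0, 5.0): | |||
|     print(x, std_trans(x)) | |||
| # -5.0 -> about 0.0067 (small but positive), 0.0 -> about 0.693, 5.0 -> about 5.0067 | |||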
| @@ -21,6 +21,7 @@ from mindspore.common import dtype as mstype | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| import mindspore.nn as nn | |||
| import mindspore.nn.probability as msp | |||
| def cast_to_tensor(t, hint_dtype=mstype.float32): | |||
| """ | |||
| @@ -84,7 +85,7 @@ def check_scalar_from_param(params): | |||
| Notes: String parameters are excluded. | |||
| """ | |||
| for value in params.values(): | |||
| if isinstance(value, (nn.probability.bijector.Bijector, nn.probability.distribution.Distribution)): | |||
| if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)): | |||
| return params['distribution'].is_scalar_batch | |||
| if isinstance(value, Parameter): | |||
| return False | |||
| @@ -109,7 +110,7 @@ def calc_broadcast_shape_from_param(params): | |||
| """ | |||
| broadcast_shape = [] | |||
| for value in params.values(): | |||
| if isinstance(value, (nn.probability.bijector.Bijector, nn.probability.distribution.Distribution)): | |||
| if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)): | |||
| return params['distribution'].broadcast_shape | |||
| if isinstance(value, (str, type(params['dtype']))): | |||
| continue | |||
| @@ -36,7 +36,7 @@ class _CodeTransformer(ast.NodeTransformer): | |||
| def visit_FunctionDef(self, node): | |||
| """visit function and add kl_loss computation.""" | |||
| self.generic_visit(node) | |||
| if node.name == 'compute_kl_loss': | |||
| if node.name == 'cal_kl_loss': | |||
| for i in range(self.layer_count): | |||
| func = ast.Assign(targets=[ast.Name(id='loss', ctx=ast.Store())], | |||
| value=ast.BinOp(left=ast.Name(id='loss', ctx=ast.Load()), op=ast.Add(), | |||
| @@ -71,7 +71,7 @@ def gain_bnn_with_loss(layer_count, backbone, loss_fn, dnn_factor, bnn_factor): | |||
| layer_count (int): The number of kl loss to be generated, namely the number of Bayesian layers. | |||
| backbone (Cell): The target network to wrap. | |||
| loss_fn (Cell): The loss function used to compute loss. | |||
| dnn_factor ((int, float): The coefficient of backbone's loss, which is computed by loss function. | |||
| dnn_factor (int, float): The coefficient of backbone's loss, which is computed by loss function. | |||
| bnn_factor (int, float): The coefficient of kl loss, which is kl divergence of Bayesian layer. | |||
| """ | |||
| bnn_loss_func = _generate_kl_loss_func(layer_count) | |||
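| # Illustration only (not part of this patch): a minimal sketch of the ast-based | |||
| # rewriting technique used by _CodeTransformer, which is also why astunparse is | |||
| # added to the requirements below. The template function body and the | |||
| # `self.layer_<i>` attribute names here are hypothetical. | |||
| import ast | |||
| import astunparse | |||
| TEMPLATE = """ | |||
| def cal_kl_loss(self): | |||
|     loss = 0 | |||
|     return loss | |||
| """ | |||
| class AddKlTerms(ast.NodeTransformer): | |||
|     def __init__(self, layer_count): | |||
|         self.layer_count = layer_count | |||
|     def visit_FunctionDef(self, node): | |||
|         self.generic_visit(node) | |||
|         if node.name == 'cal_kl_loss': | |||
|             # Insert `loss = loss + self.layer_<i>.compute_kl_loss()` before the return. | |||
|             for i in range(self.layer_count): | |||
|                 stmt = ast.parse('loss = loss + self.layer_{}.compute_kl_loss()'.format(i)).body[0] | |||
|                 node.body.insert(len(node.body) - 1, stmt) | |||
|         return node | |||
| tree = AddKlTerms(layer_count=2).visit(ast.parse(TEMPLATE)) | |||
| print(astunparse.unparse(tree)) | |||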
| @@ -14,3 +14,4 @@ opencv-python >= 4.1.2.30 # for ut test | |||
| sklearn >= 0.0 # for st test | |||
| pandas >= 1.0.2 # for ut test | |||
| bs4 | |||
| astunparse | |||
| @@ -92,7 +92,8 @@ required_package = [ | |||
| 'easydict >= 1.9', | |||
| 'sympy >= 1.4', | |||
| 'cffi >= 1.13.2', | |||
| 'decorator >= 4.4.0' | |||
| 'decorator >= 4.4.0', | |||
| 'astunparse >= 1.6.3' | |||
| ] | |||
| package_data = { | |||