| @@ -1,291 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """MobileNetV2 model define""" | |||||
| import numpy as np | |||||
| import mindspore.nn as nn | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops.operations import TensorAdd | |||||
| from mindspore import Parameter, Tensor | |||||
| from mindspore.common.initializer import initializer | |||||
| __all__ = ['mobilenet_v2'] | |||||
| def _make_divisible(v, divisor, min_value=None): | |||||
| if min_value is None: | |||||
| min_value = divisor | |||||
| new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) | |||||
| # Make sure that round down does not go down by more than 10%. | |||||
| if new_v < 0.9 * v: | |||||
| new_v += divisor | |||||
| return new_v | |||||
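| # Illustrative rounding behaviour (examples added for clarity, not from the original file): | |||||
| #   _make_divisible(37, 8)  -> 40   # nearest multiple of 8 | |||||
| #   _make_divisible(30, 8)  -> 32 | |||||
| #   _make_divisible(8.4, 8) -> 8    # never below min_value, never more than 10% below v | |||||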
| class GlobalAvgPooling(nn.Cell): | |||||
| """ | |||||
| Global avg pooling definition. | |||||
| Args: | |||||
| Returns: | |||||
| Tensor, output tensor. | |||||
| Examples: | |||||
| >>> GlobalAvgPooling() | |||||
| """ | |||||
| def __init__(self): | |||||
| super(GlobalAvgPooling, self).__init__() | |||||
| self.mean = P.ReduceMean(keep_dims=False) | |||||
| def construct(self, x): | |||||
| x = self.mean(x, (2, 3)) | |||||
| return x | |||||
| class DepthwiseConv(nn.Cell): | |||||
| """ | |||||
| Depthwise Convolution wrapper definition. | |||||
| Args: | |||||
| in_planes (int): Input channel. | |||||
| kernel_size (int): Input kernel size. | |||||
| stride (int): Stride size. | |||||
| pad_mode (str): pad mode in (pad, same, valid) | |||||
| channel_multiplier (int): Output channel multiplier | |||||
| has_bias (bool): has bias or not | |||||
| Returns: | |||||
| Tensor, output tensor. | |||||
| Examples: | |||||
| >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1) | |||||
| """ | |||||
| def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False): | |||||
| super(DepthwiseConv, self).__init__() | |||||
| self.has_bias = has_bias | |||||
| self.in_channels = in_planes | |||||
| self.channel_multiplier = channel_multiplier | |||||
| self.out_channels = in_planes * channel_multiplier | |||||
| self.kernel_size = (kernel_size, kernel_size) | |||||
| self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, | |||||
| kernel_size=self.kernel_size, | |||||
| stride=stride, pad_mode=pad_mode, pad=pad) | |||||
| self.bias_add = P.BiasAdd() | |||||
| weight_shape = [channel_multiplier, in_planes, *self.kernel_size] | |||||
| self.weight = Parameter(initializer('ones', weight_shape), name='weight') | |||||
| if has_bias: | |||||
| bias_shape = [channel_multiplier * in_planes] | |||||
| self.bias = Parameter(initializer('zeros', bias_shape), name='bias') | |||||
| else: | |||||
| self.bias = None | |||||
| def construct(self, x): | |||||
| output = self.depthwise_conv(x, self.weight) | |||||
| if self.has_bias: | |||||
| output = self.bias_add(output, self.bias) | |||||
| return output | |||||
| class ConvBNReLU(nn.Cell): | |||||
| """ | |||||
| Convolution/Depthwise fused with Batchnorm and ReLU block definition. | |||||
| Args: | |||||
| in_planes (int): Input channel. | |||||
| out_planes (int): Output channel. | |||||
| kernel_size (int): Input kernel size. | |||||
| stride (int): Stride size for the first convolutional layer. Default: 1. | |||||
| groups (int): Channel group. 1 for standard convolution; equal to the input channel count for depthwise convolution. Default: 1. | |||||
| Returns: | |||||
| Tensor, output tensor. | |||||
| Examples: | |||||
| >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) | |||||
| """ | |||||
| def __init__(self, platform, in_planes, out_planes, kernel_size=3, stride=1, groups=1): | |||||
| super(ConvBNReLU, self).__init__() | |||||
| padding = (kernel_size - 1) // 2 | |||||
| if groups == 1: | |||||
| conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding) | |||||
| else: | |||||
| if platform == "Ascend": | |||||
| conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding) | |||||
| elif platform == "GPU": | |||||
| conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, | |||||
| group=in_planes, pad_mode='pad', padding=padding) | |||||
| layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()] | |||||
| self.features = nn.SequentialCell(layers) | |||||
| def construct(self, x): | |||||
| output = self.features(x) | |||||
| return output | |||||
| class InvertedResidual(nn.Cell): | |||||
| """ | |||||
| Mobilenetv2 residual block definition. | |||||
| Args: | |||||
| inp (int): Input channel. | |||||
| oup (int): Output channel. | |||||
| stride (int): Stride size for the first convolutional layer. Default: 1. | |||||
| expand_ratio (int): expand ratio of input channel. | |||||
| Returns: | |||||
| Tensor, output tensor. | |||||
| Examples: | |||||
| >>> InvertedResidual("Ascend", 3, 256, 1, 1) | |||||
| """ | |||||
| def __init__(self, platform, inp, oup, stride, expand_ratio): | |||||
| super(InvertedResidual, self).__init__() | |||||
| assert stride in [1, 2] | |||||
| hidden_dim = int(round(inp * expand_ratio)) | |||||
| self.use_res_connect = stride == 1 and inp == oup | |||||
| layers = [] | |||||
| if expand_ratio != 1: | |||||
| layers.append(ConvBNReLU(platform, inp, hidden_dim, kernel_size=1)) | |||||
| layers.extend([ | |||||
| # dw | |||||
| ConvBNReLU(platform, hidden_dim, hidden_dim, | |||||
| stride=stride, groups=hidden_dim), | |||||
| # pw-linear | |||||
| nn.Conv2d(hidden_dim, oup, kernel_size=1, | |||||
| stride=1, has_bias=False), | |||||
| nn.BatchNorm2d(oup), | |||||
| ]) | |||||
| self.conv = nn.SequentialCell(layers) | |||||
| self.add = TensorAdd() | |||||
| self.cast = P.Cast() | |||||
| def construct(self, x): | |||||
| identity = x | |||||
| x = self.conv(x) | |||||
| if self.use_res_connect: | |||||
| return self.add(identity, x) | |||||
| return x | |||||
| class MobileNetV2(nn.Cell): | |||||
| """ | |||||
| MobileNetV2 architecture. | |||||
| Args: | |||||
| num_classes (int): Number of classes. Default is 1000. | |||||
| width_mult (float): Channel width multiplier; channel counts are rounded via round_nearest. Default is 1. | |||||
| has_dropout (bool): Whether dropout is used before the classifier. Default is False. | |||||
| inverted_residual_setting (list): Inverted residual settings. Default is None. | |||||
| round_nearest (int): Round channel counts to a multiple of this value. Default is 8. | |||||
| Returns: | |||||
| Tensor, output tensor. | |||||
| Examples: | |||||
| >>> MobileNetV2("Ascend", num_classes=1000) | |||||
| """ | |||||
| def __init__(self, platform, num_classes=1000, width_mult=1., | |||||
| has_dropout=False, inverted_residual_setting=None, round_nearest=8): | |||||
| super(MobileNetV2, self).__init__() | |||||
| block = InvertedResidual | |||||
| input_channel = 32 | |||||
| last_channel = 1280 | |||||
| # setting of inverted residual blocks | |||||
| self.cfgs = inverted_residual_setting | |||||
| if inverted_residual_setting is None: | |||||
| self.cfgs = [ | |||||
| # t, c, n, s | |||||
| [1, 16, 1, 1], | |||||
| [6, 24, 2, 2], | |||||
| [6, 32, 3, 2], | |||||
| [6, 64, 4, 2], | |||||
| [6, 96, 3, 1], | |||||
| [6, 160, 3, 2], | |||||
| [6, 320, 1, 1], | |||||
| ] | |||||
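| # Each row above is (t, c, n, s): expansion ratio t, output channels c (before width_mult), | |||||
| # number of repeated blocks n, and stride s of the first block in each group. | |||||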
| # building first layer | |||||
| input_channel = _make_divisible(input_channel * width_mult, round_nearest) | |||||
| self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) | |||||
| features = [ConvBNReLU(platform, 3, input_channel, stride=2)] | |||||
| # building inverted residual blocks | |||||
| for t, c, n, s in self.cfgs: | |||||
| output_channel = _make_divisible(c * width_mult, round_nearest) | |||||
| for i in range(n): | |||||
| stride = s if i == 0 else 1 | |||||
| features.append(block(platform, input_channel, output_channel, stride, expand_ratio=t)) | |||||
| input_channel = output_channel | |||||
| # building last several layers | |||||
| features.append(ConvBNReLU(platform, input_channel, self.out_channels, kernel_size=1)) | |||||
| # make it nn.CellList | |||||
| self.features = nn.SequentialCell(features) | |||||
| # mobilenet head | |||||
| head = ([GlobalAvgPooling(), nn.Dense(self.out_channels, num_classes, has_bias=True)] if not has_dropout else | |||||
| [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(self.out_channels, num_classes, has_bias=True)]) | |||||
| self.head = nn.SequentialCell(head) | |||||
| self._initialize_weights() | |||||
| def construct(self, x): | |||||
| x = self.features(x) | |||||
| x = self.head(x) | |||||
| return x | |||||
| def _initialize_weights(self): | |||||
| """ | |||||
| Initialize weights. | |||||
| Args: | |||||
| Returns: | |||||
| None. | |||||
| Examples: | |||||
| >>> _initialize_weights() | |||||
| """ | |||||
| for _, m in self.cells_and_names(): | |||||
| if isinstance(m, (nn.Conv2d, DepthwiseConv)): | |||||
| n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | |||||
| m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), | |||||
| m.weight.data.shape).astype("float32"))) | |||||
| if m.bias is not None: | |||||
| m.bias.set_parameter_data( | |||||
| Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) | |||||
| elif isinstance(m, nn.BatchNorm2d): | |||||
| m.gamma.set_parameter_data( | |||||
| Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) | |||||
| m.beta.set_parameter_data( | |||||
| Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) | |||||
| elif isinstance(m, nn.Dense): | |||||
| m.weight.set_parameter_data(Tensor(np.random.normal( | |||||
| 0, 0.01, m.weight.data.shape).astype("float32"))) | |||||
| if m.bias is not None: | |||||
| m.bias.set_parameter_data( | |||||
| Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) | |||||
| def mobilenet_v2(**kwargs): | |||||
| """ | |||||
| Constructs a MobileNet V2 model | |||||
| """ | |||||
| return MobileNetV2(**kwargs) | |||||
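| # Usage sketch (illustrative; platform must be "Ascend" or "GPU"): | |||||
| #   net = mobilenet_v2(platform="Ascend", num_classes=1000, width_mult=1.0) | |||||
| #   logits = net(Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)))  # shape (1, 1000) | |||||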
| @@ -1,390 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """MobileNetV3 model define""" | |||||
| from functools import partial | |||||
| import numpy as np | |||||
| import mindspore.nn as nn | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore import Tensor | |||||
| __all__ = ['mobilenet_v3_large', | |||||
| 'mobilenet_v3_small'] | |||||
| def _make_divisible(x, divisor=8): | |||||
| return int(np.ceil(x * 1. / divisor) * divisor) | |||||
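| # Unlike the MobileNetV2 helper, this version always rounds up to a multiple of divisor, | |||||
| # e.g. _make_divisible(12.3) -> 16 and _make_divisible(16) -> 16. | |||||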
| class Activation(nn.Cell): | |||||
| """ | |||||
| Activation definition. | |||||
| Args: | |||||
| act_func(string): activation name. | |||||
| Returns: | |||||
| Tensor, output tensor. | |||||
| """ | |||||
| def __init__(self, act_func): | |||||
| super(Activation, self).__init__() | |||||
| if act_func == 'relu': | |||||
| self.act = nn.ReLU() | |||||
| elif act_func == 'relu6': | |||||
| self.act = nn.ReLU6() | |||||
| elif act_func in ('hsigmoid', 'hard_sigmoid'): | |||||
| self.act = nn.HSigmoid() | |||||
| elif act_func in ('hswish', 'hard_swish'): | |||||
| self.act = nn.HSwish() | |||||
| else: | |||||
| raise NotImplementedError | |||||
| def construct(self, x): | |||||
| return self.act(x) | |||||
| class GlobalAvgPooling(nn.Cell): | |||||
| """ | |||||
| Global avg pooling definition. | |||||
| Args: | |||||
| Returns: | |||||
| Tensor, output tensor. | |||||
| Examples: | |||||
| >>> GlobalAvgPooling() | |||||
| """ | |||||
| def __init__(self, keep_dims=False): | |||||
| super(GlobalAvgPooling, self).__init__() | |||||
| self.mean = P.ReduceMean(keep_dims=keep_dims) | |||||
| def construct(self, x): | |||||
| x = self.mean(x, (2, 3)) | |||||
| return x | |||||
| class SE(nn.Cell): | |||||
| """ | |||||
| SE wrapper definition. | |||||
| Args: | |||||
| num_out (int): Output channel. | |||||
| ratio (int): middle output ratio. | |||||
| Returns: | |||||
| Tensor, output tensor. | |||||
| Examples: | |||||
| >>> SE(4) | |||||
| """ | |||||
| def __init__(self, num_out, ratio=4): | |||||
| super(SE, self).__init__() | |||||
| num_mid = _make_divisible(num_out // ratio) | |||||
| self.pool = GlobalAvgPooling(keep_dims=True) | |||||
| self.conv1 = nn.Conv2d(in_channels=num_out, out_channels=num_mid, | |||||
| kernel_size=1, has_bias=True, pad_mode='pad') | |||||
| self.act1 = Activation('relu') | |||||
| self.conv2 = nn.Conv2d(in_channels=num_mid, out_channels=num_out, | |||||
| kernel_size=1, has_bias=True, pad_mode='pad') | |||||
| self.act2 = Activation('hsigmoid') | |||||
| self.mul = P.Mul() | |||||
| def construct(self, x): | |||||
| out = self.pool(x) | |||||
| out = self.conv1(out) | |||||
| out = self.act1(out) | |||||
| out = self.conv2(out) | |||||
| out = self.act2(out) | |||||
| out = self.mul(x, out) | |||||
| return out | |||||
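| # SE implements squeeze-and-excitation: global average pooling to 1x1, a 1x1 conv that | |||||
| # reduces channels by `ratio`, ReLU, a 1x1 conv back to num_out, a hard-sigmoid gate, | |||||
| # and an element-wise rescale of the input feature map. | |||||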
| class Unit(nn.Cell): | |||||
| """ | |||||
| Unit wrapper definition. | |||||
| Args: | |||||
| num_in (int): Input channel. | |||||
| num_out (int): Output channel. | |||||
| kernel_size (int): Input kernel size. | |||||
| stride (int): Stride size. | |||||
| padding (int): Padding number. | |||||
| num_groups (int): Number of convolution groups. Default: 1. | |||||
| use_act (bool): Whether to apply the activation. Default: True. | |||||
| act_type (string): Activation type. | |||||
| Returns: | |||||
| Tensor, output tensor. | |||||
| Examples: | |||||
| >>> Unit(3, 3) | |||||
| """ | |||||
| def __init__(self, num_in, num_out, kernel_size=1, stride=1, padding=0, num_groups=1, | |||||
| use_act=True, act_type='relu'): | |||||
| super(Unit, self).__init__() | |||||
| self.conv = nn.Conv2d(in_channels=num_in, | |||||
| out_channels=num_out, | |||||
| kernel_size=kernel_size, | |||||
| stride=stride, | |||||
| padding=padding, | |||||
| group=num_groups, | |||||
| has_bias=False, | |||||
| pad_mode='pad') | |||||
| self.bn = nn.BatchNorm2d(num_out) | |||||
| self.use_act = use_act | |||||
| self.act = Activation(act_type) if use_act else None | |||||
| def construct(self, x): | |||||
| out = self.conv(x) | |||||
| out = self.bn(out) | |||||
| if self.use_act: | |||||
| out = self.act(out) | |||||
| return out | |||||
| class ResUnit(nn.Cell): | |||||
| """ | |||||
| ResUnit wrapper definition. | |||||
| Args: | |||||
| num_in (int): Input channel. | |||||
| num_mid (int): Middle channel. | |||||
| num_out (int): Output channel. | |||||
| kernel_size (int): Input kernel size. | |||||
| stride (int): Stride size. | |||||
| act_type (str): Activation type. | |||||
| use_se (bool): Whether to use the SE wrapper. | |||||
| Returns: | |||||
| Tensor, output tensor. | |||||
| Examples: | |||||
| >>> ResUnit(16, 3, 1, 1) | |||||
| """ | |||||
| def __init__(self, num_in, num_mid, num_out, kernel_size, stride=1, act_type='relu', use_se=False): | |||||
| super(ResUnit, self).__init__() | |||||
| self.use_se = use_se | |||||
| self.first_conv = (num_out != num_mid) | |||||
| self.use_short_cut_conv = True | |||||
| if self.first_conv: | |||||
| self.expand = Unit(num_in, num_mid, kernel_size=1, | |||||
| stride=1, padding=0, act_type=act_type) | |||||
| else: | |||||
| self.expand = None | |||||
| self.conv1 = Unit(num_mid, num_mid, kernel_size=kernel_size, stride=stride, | |||||
| padding=self._get_pad(kernel_size), act_type=act_type, num_groups=num_mid) | |||||
| if use_se: | |||||
| self.se = SE(num_mid) | |||||
| self.conv2 = Unit(num_mid, num_out, kernel_size=1, stride=1, | |||||
| padding=0, act_type=act_type, use_act=False) | |||||
| if num_in != num_out or stride != 1: | |||||
| self.use_short_cut_conv = False | |||||
| self.add = P.TensorAdd() if self.use_short_cut_conv else None | |||||
| def construct(self, x): | |||||
| if self.first_conv: | |||||
| out = self.expand(x) | |||||
| else: | |||||
| out = x | |||||
| out = self.conv1(out) | |||||
| if self.use_se: | |||||
| out = self.se(out) | |||||
| out = self.conv2(out) | |||||
| if self.use_short_cut_conv: | |||||
| out = self.add(x, out) | |||||
| return out | |||||
| def _get_pad(self, kernel_size): | |||||
| """set the padding number""" | |||||
| pad = 0 | |||||
| if kernel_size == 1: | |||||
| pad = 0 | |||||
| elif kernel_size == 3: | |||||
| pad = 1 | |||||
| elif kernel_size == 5: | |||||
| pad = 2 | |||||
| elif kernel_size == 7: | |||||
| pad = 3 | |||||
| else: | |||||
| raise NotImplementedError | |||||
| return pad | |||||
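| # ResUnit follows the MobileNetV3 bottleneck layout: an optional 1x1 expansion (when | |||||
| # num_mid != num_out), a depthwise conv (num_groups=num_mid), an optional SE block, a 1x1 | |||||
| # linear projection, and an identity shortcut only when stride == 1 and num_in == num_out. | |||||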
| class MobileNetV3(nn.Cell): | |||||
| """ | |||||
| MobileNetV3 architecture. | |||||
| Args: | |||||
| model_cfgs (dict): Model configuration, including block settings and head channel sizes. | |||||
| num_classes (int): Number of output classes. Default is 1000. | |||||
| multiplier (float): Channel width multiplier; channel counts are rounded to a multiple of 8. Default is 1. | |||||
| final_drop (float): Dropout rate of the final dropout layer. Default is 0. | |||||
| round_nearest (int): Round channel counts to a multiple of this value. Default is 8. | |||||
| Returns: | |||||
| Tensor, output tensor. | |||||
| Examples: | |||||
| >>> MobileNetV3(num_classes=1000) | |||||
| """ | |||||
| def __init__(self, model_cfgs, num_classes=1000, multiplier=1., final_drop=0., round_nearest=8): | |||||
| super(MobileNetV3, self).__init__() | |||||
| self.cfgs = model_cfgs['cfg'] | |||||
| self.inplanes = 16 | |||||
| self.features = [] | |||||
| first_conv_in_channel = 3 | |||||
| first_conv_out_channel = _make_divisible(multiplier * self.inplanes) | |||||
| self.features.append(nn.Conv2d(in_channels=first_conv_in_channel, | |||||
| out_channels=first_conv_out_channel, | |||||
| kernel_size=3, padding=1, stride=2, | |||||
| has_bias=False, pad_mode='pad')) | |||||
| self.features.append(nn.BatchNorm2d(first_conv_out_channel)) | |||||
| self.features.append(Activation('hswish')) | |||||
| for layer_cfg in self.cfgs: | |||||
| self.features.append(self._make_layer(kernel_size=layer_cfg[0], | |||||
| exp_ch=_make_divisible(multiplier * layer_cfg[1]), | |||||
| out_channel=_make_divisible(multiplier * layer_cfg[2]), | |||||
| use_se=layer_cfg[3], | |||||
| act_func=layer_cfg[4], | |||||
| stride=layer_cfg[5])) | |||||
| output_channel = _make_divisible(multiplier * model_cfgs["cls_ch_squeeze"]) | |||||
| self.features.append(nn.Conv2d(in_channels=_make_divisible(multiplier * self.cfgs[-1][2]), | |||||
| out_channels=output_channel, | |||||
| kernel_size=1, padding=0, stride=1, | |||||
| has_bias=False, pad_mode='pad')) | |||||
| self.features.append(nn.BatchNorm2d(output_channel)) | |||||
| self.features.append(Activation('hswish')) | |||||
| self.features.append(GlobalAvgPooling(keep_dims=True)) | |||||
| self.features.append(nn.Conv2d(in_channels=output_channel, | |||||
| out_channels=model_cfgs['cls_ch_expand'], | |||||
| kernel_size=1, padding=0, stride=1, | |||||
| has_bias=False, pad_mode='pad')) | |||||
| self.features.append(Activation('hswish')) | |||||
| if final_drop > 0: | |||||
| self.features.append((nn.Dropout(final_drop))) | |||||
| # make it nn.CellList | |||||
| self.features = nn.SequentialCell(self.features) | |||||
| self.output = nn.Conv2d(in_channels=model_cfgs['cls_ch_expand'], | |||||
| out_channels=num_classes, | |||||
| kernel_size=1, has_bias=True, pad_mode='pad') | |||||
| self.squeeze = P.Squeeze(axis=(2, 3)) | |||||
| self._initialize_weights() | |||||
| def construct(self, x): | |||||
| x = self.features(x) | |||||
| x = self.output(x) | |||||
| x = self.squeeze(x) | |||||
| return x | |||||
| def _make_layer(self, kernel_size, exp_ch, out_channel, use_se, act_func, stride=1): | |||||
| mid_planes = exp_ch | |||||
| out_planes = out_channel | |||||
| # ResUnit(num_in, num_mid, num_out, kernel_size, stride=1, act_type='relu', use_se=False) | |||||
| layer = ResUnit(self.inplanes, mid_planes, out_planes, | |||||
| kernel_size, stride=stride, act_type=act_func, use_se=use_se) | |||||
| self.inplanes = out_planes | |||||
| return layer | |||||
| def _initialize_weights(self): | |||||
| """ | |||||
| Initialize weights. | |||||
| Args: | |||||
| Returns: | |||||
| None. | |||||
| Examples: | |||||
| >>> _initialize_weights() | |||||
| """ | |||||
| for _, m in self.cells_and_names(): | |||||
| if isinstance(m, (nn.Conv2d)): | |||||
| n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | |||||
| m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), | |||||
| m.weight.data.shape).astype("float32"))) | |||||
| if m.bias is not None: | |||||
| m.bias.set_parameter_data( | |||||
| Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) | |||||
| elif isinstance(m, nn.BatchNorm2d): | |||||
| m.gamma.set_parameter_data( | |||||
| Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) | |||||
| m.beta.set_parameter_data( | |||||
| Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) | |||||
| elif isinstance(m, nn.Dense): | |||||
| m.weight.set_parameter_data(Tensor(np.random.normal( | |||||
| 0, 0.01, m.weight.data.shape).astype("float32"))) | |||||
| if m.bias is not None: | |||||
| m.bias.set_parameter_data( | |||||
| Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) | |||||
| def mobilenet_v3(model_name, **kwargs): | |||||
| """ | |||||
| Constructs a MobileNet V3 model | |||||
| """ | |||||
| model_cfgs = { | |||||
| "large": { | |||||
| "cfg": [ | |||||
| # k, exp, c, se, nl, s, | |||||
| [3, 16, 16, False, 'relu', 1], | |||||
| [3, 64, 24, False, 'relu', 2], | |||||
| [3, 72, 24, False, 'relu', 1], | |||||
| [5, 72, 40, True, 'relu', 2], | |||||
| [5, 120, 40, True, 'relu', 1], | |||||
| [5, 120, 40, True, 'relu', 1], | |||||
| [3, 240, 80, False, 'hswish', 2], | |||||
| [3, 200, 80, False, 'hswish', 1], | |||||
| [3, 184, 80, False, 'hswish', 1], | |||||
| [3, 184, 80, False, 'hswish', 1], | |||||
| [3, 480, 112, True, 'hswish', 1], | |||||
| [3, 672, 112, True, 'hswish', 1], | |||||
| [5, 672, 160, True, 'hswish', 2], | |||||
| [5, 960, 160, True, 'hswish', 1], | |||||
| [5, 960, 160, True, 'hswish', 1]], | |||||
| "cls_ch_squeeze": 960, | |||||
| "cls_ch_expand": 1280, | |||||
| }, | |||||
| "small": { | |||||
| "cfg": [ | |||||
| # k, exp, c, se, nl, s, | |||||
| [3, 16, 16, True, 'relu', 2], | |||||
| [3, 72, 24, False, 'relu', 2], | |||||
| [3, 88, 24, False, 'relu', 1], | |||||
| [5, 96, 40, True, 'hswish', 2], | |||||
| [5, 240, 40, True, 'hswish', 1], | |||||
| [5, 240, 40, True, 'hswish', 1], | |||||
| [5, 120, 48, True, 'hswish', 1], | |||||
| [5, 144, 48, True, 'hswish', 1], | |||||
| [5, 288, 96, True, 'hswish', 2], | |||||
| [5, 576, 96, True, 'hswish', 1], | |||||
| [5, 576, 96, True, 'hswish', 1]], | |||||
| "cls_ch_squeeze": 576, | |||||
| "cls_ch_expand": 1280, | |||||
| } | |||||
| } | |||||
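| # Each cfg row is (k, exp, c, se, nl, s): kernel size, expansion width, output channels, | |||||
| # whether to use an SE block, nonlinearity, and stride. | |||||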
| return MobileNetV3(model_cfgs[model_name], **kwargs) | |||||
| mobilenet_v3_large = partial(mobilenet_v3, model_name="large") | |||||
| mobilenet_v3_small = partial(mobilenet_v3, model_name="small") | |||||
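| # Usage sketch (illustrative): | |||||
| #   net = mobilenet_v3_large(num_classes=1000) | |||||
| #   net = mobilenet_v3_small(num_classes=1000, multiplier=1.0) | |||||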
| @@ -279,7 +279,7 @@ class FakeQuantWithMinMax(Cell): | |||||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | ||||
| ema (bool): Exponential Moving Average algorithm update min and max. Default: False. | ema (bool): Exponential Moving Average algorithm update min and max. Default: False. | ||||
| ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | ||||
| per_channel (bool): Quantization by layer or channel. Default: False. | |||||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||||
| channel_axis (int): Quantization by channel axis. Default: 1. | channel_axis (int): Quantization by channel axis. Default: 1. | ||||
| out_channels (int): declare the min and max channel size, Default: 1. | out_channels (int): declare the min and max channel size, Default: 1. | ||||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | quant_delay (int): Quantization delay parameters according by global step. Default: 0. | ||||
| @@ -407,7 +407,7 @@ class Conv2dBatchNormQuant(Cell): | |||||
| freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000. | freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000. | ||||
| fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True. | fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True. | ||||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | ||||
| per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. | |||||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | ||||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | ||||
| @@ -584,7 +584,7 @@ class Conv2dQuant(Cell): | |||||
| bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'. | bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'. | ||||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | quant_delay (int): Quantization delay parameters according by global step. Default: 0. | ||||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | ||||
| per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. | |||||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | ||||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | ||||
| @@ -694,7 +694,7 @@ class DenseQuant(Cell): | |||||
| activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. | activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. | ||||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | ||||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | quant_delay (int): Quantization delay parameters according by global step. Default: 0. | ||||
| per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. | |||||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | ||||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | ||||
| @@ -797,6 +797,7 @@ class ReLUQuant(_QuantActivation): | |||||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | ||||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | quant_delay (int): Quantization delay parameters according by global step. Default: 0. | ||||
| ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | ||||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | ||||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | ||||
| @@ -816,6 +817,7 @@ class ReLUQuant(_QuantActivation): | |||||
| num_bits=8, | num_bits=8, | ||||
| quant_delay=0, | quant_delay=0, | ||||
| ema_decay=0.999, | ema_decay=0.999, | ||||
| per_channel=False, | |||||
| symmetric=False, | symmetric=False, | ||||
| narrow_range=False): | narrow_range=False): | ||||
| super(ReLUQuant, self).__init__() | super(ReLUQuant, self).__init__() | ||||
| @@ -824,6 +826,7 @@ class ReLUQuant(_QuantActivation): | |||||
| num_bits=num_bits, | num_bits=num_bits, | ||||
| quant_delay=quant_delay, | quant_delay=quant_delay, | ||||
| ema=True, | ema=True, | ||||
| per_channel=per_channel, | |||||
| ema_decay=ema_decay, | ema_decay=ema_decay, | ||||
| symmetric=symmetric, | symmetric=symmetric, | ||||
| narrow_range=narrow_range) | narrow_range=narrow_range) | ||||
| @@ -850,6 +853,7 @@ class ReLU6Quant(_QuantActivation): | |||||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | ||||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | quant_delay (int): Quantization delay parameters according by global step. Default: 0. | ||||
| ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | ||||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | ||||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | ||||
| @@ -869,6 +873,7 @@ class ReLU6Quant(_QuantActivation): | |||||
| num_bits=8, | num_bits=8, | ||||
| quant_delay=0, | quant_delay=0, | ||||
| ema_decay=0.999, | ema_decay=0.999, | ||||
| per_channel=False, | |||||
| symmetric=False, | symmetric=False, | ||||
| narrow_range=False): | narrow_range=False): | ||||
| super(ReLU6Quant, self).__init__() | super(ReLU6Quant, self).__init__() | ||||
| @@ -877,6 +882,7 @@ class ReLU6Quant(_QuantActivation): | |||||
| num_bits=num_bits, | num_bits=num_bits, | ||||
| quant_delay=quant_delay, | quant_delay=quant_delay, | ||||
| ema=True, | ema=True, | ||||
| per_channel=per_channel, | |||||
| ema_decay=ema_decay, | ema_decay=ema_decay, | ||||
| symmetric=symmetric, | symmetric=symmetric, | ||||
| narrow_range=narrow_range) | narrow_range=narrow_range) | ||||
| @@ -900,6 +906,7 @@ class HSwishQuant(_QuantActivation): | |||||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | ||||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | quant_delay (int): Quantization delay parameters according by global step. Default: 0. | ||||
| ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | ||||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | ||||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | ||||
| @@ -919,6 +926,7 @@ class HSwishQuant(_QuantActivation): | |||||
| num_bits=8, | num_bits=8, | ||||
| quant_delay=0, | quant_delay=0, | ||||
| ema_decay=0.999, | ema_decay=0.999, | ||||
| per_channel=False, | |||||
| symmetric=False, | symmetric=False, | ||||
| narrow_range=False): | narrow_range=False): | ||||
| super(HSwishQuant, self).__init__() | super(HSwishQuant, self).__init__() | ||||
| @@ -927,6 +935,7 @@ class HSwishQuant(_QuantActivation): | |||||
| num_bits=num_bits, | num_bits=num_bits, | ||||
| quant_delay=quant_delay, | quant_delay=quant_delay, | ||||
| ema=True, | ema=True, | ||||
| per_channel=per_channel, | |||||
| ema_decay=ema_decay, | ema_decay=ema_decay, | ||||
| symmetric=symmetric, | symmetric=symmetric, | ||||
| narrow_range=narrow_range) | narrow_range=narrow_range) | ||||
| @@ -935,6 +944,7 @@ class HSwishQuant(_QuantActivation): | |||||
| num_bits=num_bits, | num_bits=num_bits, | ||||
| quant_delay=quant_delay, | quant_delay=quant_delay, | ||||
| ema=True, | ema=True, | ||||
| per_channel=per_channel, | |||||
| ema_decay=ema_decay, | ema_decay=ema_decay, | ||||
| symmetric=symmetric, | symmetric=symmetric, | ||||
| narrow_range=narrow_range) | narrow_range=narrow_range) | ||||
| @@ -959,6 +969,7 @@ class HSigmoidQuant(_QuantActivation): | |||||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | ||||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | quant_delay (int): Quantization delay parameters according by global step. Default: 0. | ||||
| ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | ||||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | ||||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | ||||
| @@ -978,6 +989,7 @@ class HSigmoidQuant(_QuantActivation): | |||||
| num_bits=8, | num_bits=8, | ||||
| quant_delay=0, | quant_delay=0, | ||||
| ema_decay=0.999, | ema_decay=0.999, | ||||
| per_channel=False, | |||||
| symmetric=False, | symmetric=False, | ||||
| narrow_range=False): | narrow_range=False): | ||||
| super(HSigmoidQuant, self).__init__() | super(HSigmoidQuant, self).__init__() | ||||
| @@ -986,6 +998,7 @@ class HSigmoidQuant(_QuantActivation): | |||||
| num_bits=num_bits, | num_bits=num_bits, | ||||
| quant_delay=quant_delay, | quant_delay=quant_delay, | ||||
| ema=True, | ema=True, | ||||
| per_channel=per_channel, | |||||
| symmetric=symmetric, | symmetric=symmetric, | ||||
| narrow_range=narrow_range) | narrow_range=narrow_range) | ||||
| self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, | self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, | ||||
| @@ -993,6 +1006,7 @@ class HSigmoidQuant(_QuantActivation): | |||||
| num_bits=num_bits, | num_bits=num_bits, | ||||
| quant_delay=quant_delay, | quant_delay=quant_delay, | ||||
| ema=True, | ema=True, | ||||
| per_channel=per_channel, | |||||
| ema_decay=ema_decay, | ema_decay=ema_decay, | ||||
| symmetric=symmetric, | symmetric=symmetric, | ||||
| narrow_range=narrow_range) | narrow_range=narrow_range) | ||||
| @@ -1017,6 +1031,7 @@ class TensorAddQuant(Cell): | |||||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | ||||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | quant_delay (int): Quantization delay parameters according by global step. Default: 0. | ||||
| ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | ||||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | ||||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | ||||
| @@ -1037,6 +1052,7 @@ class TensorAddQuant(Cell): | |||||
| num_bits=8, | num_bits=8, | ||||
| quant_delay=0, | quant_delay=0, | ||||
| ema_decay=0.999, | ema_decay=0.999, | ||||
| per_channel=False, | |||||
| symmetric=False, | symmetric=False, | ||||
| narrow_range=False): | narrow_range=False): | ||||
| super(TensorAddQuant, self).__init__() | super(TensorAddQuant, self).__init__() | ||||
| @@ -1045,6 +1061,7 @@ class TensorAddQuant(Cell): | |||||
| num_bits=num_bits, | num_bits=num_bits, | ||||
| quant_delay=quant_delay, | quant_delay=quant_delay, | ||||
| ema=True, | ema=True, | ||||
| per_channel=per_channel, | |||||
| ema_decay=ema_decay, | ema_decay=ema_decay, | ||||
| symmetric=symmetric, | symmetric=symmetric, | ||||
| narrow_range=narrow_range) | narrow_range=narrow_range) | ||||
| @@ -1066,6 +1083,7 @@ class MulQuant(Cell): | |||||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | ||||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | quant_delay (int): Quantization delay parameters according by global step. Default: 0. | ||||
| ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | ||||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | ||||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | ||||
| @@ -1081,6 +1099,7 @@ class MulQuant(Cell): | |||||
| num_bits=8, | num_bits=8, | ||||
| quant_delay=0, | quant_delay=0, | ||||
| ema_decay=0.999, | ema_decay=0.999, | ||||
| per_channel=False, | |||||
| symmetric=False, | symmetric=False, | ||||
| narrow_range=False): | narrow_range=False): | ||||
| super(MulQuant, self).__init__() | super(MulQuant, self).__init__() | ||||
| @@ -1089,6 +1108,7 @@ class MulQuant(Cell): | |||||
| num_bits=num_bits, | num_bits=num_bits, | ||||
| quant_delay=quant_delay, | quant_delay=quant_delay, | ||||
| ema=True, | ema=True, | ||||
| per_channel=per_channel, | |||||
| ema_decay=ema_decay, | ema_decay=ema_decay, | ||||
| symmetric=symmetric, | symmetric=symmetric, | ||||
| narrow_range=narrow_range) | narrow_range=narrow_range) | ||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """LossMonitor Callback class.""" | """LossMonitor Callback class.""" | ||||
| import time | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| @@ -31,32 +32,62 @@ class LossMonitor(Callback): | |||||
| Args: | Args: | ||||
| per_print_times (int): Print loss every times. Default: 1. | per_print_times (int): Print loss every times. Default: 1. | ||||
| lr_init (numpy.ndarray): Learning rate used during training. Default: None. | |||||
| Raises: | Raises: | ||||
| ValueError: If print_step is not int or less than zero. | ValueError: If print_step is not int or less than zero. | ||||
| Examples: | |||||
| >>> LossMonitor(100, lr_init=Tensor([0.05]*100).asnumpy()) | |||||
| """ | """ | ||||
| def __init__(self, per_print_times=1): | |||||
| def __init__(self, per_print_times=1, lr_init=None): | |||||
| super(LossMonitor, self).__init__() | super(LossMonitor, self).__init__() | ||||
| if not isinstance(per_print_times, int) or per_print_times < 0: | if not isinstance(per_print_times, int) or per_print_times < 0: | ||||
| raise ValueError("print_step must be int and >= 0.") | raise ValueError("print_step must be int and >= 0.") | ||||
| self._per_print_times = per_print_times | self._per_print_times = per_print_times | ||||
| self.lr_init = lr_init | |||||
| def epoch_begin(self, run_context): | |||||
| self.losses = [] | |||||
| self.epoch_time = time.time() | |||||
| def epoch_end(self, run_context): | |||||
| cb_params = run_context.original_args() | |||||
| epoch_mseconds = (time.time() - self.epoch_time) * 1000 | |||||
| per_step_mseconds = epoch_mseconds / cb_params.batch_num | |||||
| print("Epoch time: {:5.3f}, per step time: {:5.3f}, " | |||||
| "avg loss: {:5.3f}".format(epoch_mseconds, | |||||
| per_step_mseconds, | |||||
| np.mean(self.losses))) | |||||
| print("*" * 60) | |||||
| def step_begin(self, run_context): | |||||
| self.step_time = time.time() | |||||
| def step_end(self, run_context): | def step_end(self, run_context): | ||||
| cb_params = run_context.original_args() | cb_params = run_context.original_args() | ||||
| loss = cb_params.net_outputs | |||||
| step_mseconds = (time.time() - self.step_time) * 1000 | |||||
| step_loss = cb_params.net_outputs | |||||
| if isinstance(loss, (tuple, list)): | |||||
| if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray): | |||||
| loss = loss[0] | |||||
| if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor): | |||||
| step_loss = step_loss[0] | |||||
| if isinstance(step_loss, Tensor): | |||||
| step_loss = np.mean(step_loss.asnumpy()) | |||||
| if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray): | |||||
| loss = np.mean(loss.asnumpy()) | |||||
| self.losses.append(step_loss) | |||||
| cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num | |||||
| cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 | |||||
| if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)): | |||||
| raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. " | |||||
| "Invalid loss, terminating training.".format( | |||||
| cb_params.cur_epoch_num - 1, cb_params.epoch_num, | |||||
| cur_step_in_epoch, cb_params.batch_num)) | |||||
| if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)): | |||||
| raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format( | |||||
| cb_params.cur_epoch_num, cur_step_in_epoch)) | |||||
| if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: | if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: | ||||
| print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True) | |||||
| print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], " | |||||
| "loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format( | |||||
| cb_params.cur_epoch_num - 1, cb_params.epoch_num, | |||||
| cur_step_in_epoch, cb_params.batch_num, | |||||
| step_loss, np.mean(self.losses), | |||||
| step_mseconds), flush=True) | |||||
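| # Usage sketch (illustrative; `model`, `dataset`, `step_size`, `lr` and `TimeMonitor` are | |||||
| # assumed to be available): | |||||
| #   model.train(epoch_size, dataset, | |||||
| #               callbacks=[TimeMonitor(data_size=step_size), | |||||
| #                          LossMonitor(per_print_times=1, lr_init=lr.asnumpy())]) | |||||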
| @@ -32,4 +32,4 @@ class TimeMonitor(Callback): | |||||
| def epoch_end(self, run_context): | def epoch_end(self, run_context): | ||||
| epoch_mseconds = (time.time() - self.epoch_time) * 1000 | epoch_mseconds = (time.time() - self.epoch_time) * 1000 | ||||
| per_step_mseconds = epoch_mseconds / self.data_size | per_step_mseconds = epoch_mseconds / self.data_size | ||||
| print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True) | |||||
| print("Epoch time: {:5.3f}, per step time: {:5.3f}".format(epoch_mseconds, per_step_mseconds), flush=True) | |||||
| @@ -32,6 +32,7 @@ from ...ops.operations import _inner_ops as inner | |||||
| from ...train import serialization | from ...train import serialization | ||||
| from . import quant_utils | from . import quant_utils | ||||
| _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant, | _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant, | ||||
| nn.ReLU6: quant.ReLU6Quant, | nn.ReLU6: quant.ReLU6Quant, | ||||
| nn.HSigmoid: quant.HSigmoidQuant, | nn.HSigmoid: quant.HSigmoidQuant, | ||||
| @@ -61,14 +62,17 @@ class _AddFakeQuantAfterSubCell(nn.Cell): | |||||
| Add FakeQuant after of the sub Cell. | Add FakeQuant after of the sub Cell. | ||||
| """ | """ | ||||
| def __init__(self, subcell, quant_delay=0, num_bits=8): | |||||
| def __init__(self, subcell, **kwargs): | |||||
| super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False) | super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False) | ||||
| self.subcell = subcell | self.subcell = subcell | ||||
| self.fake_quant_act = quant.FakeQuantWithMinMax(min_init=-6, | self.fake_quant_act = quant.FakeQuantWithMinMax(min_init=-6, | ||||
| max_init=6, | max_init=6, | ||||
| num_bits=num_bits, | |||||
| quant_delay=quant_delay, | |||||
| ema=True) | |||||
| ema=True, | |||||
| num_bits=kwargs["num_bits"], | |||||
| quant_delay=kwargs["quant_delay"], | |||||
| per_channel=kwargs["per_channel"], | |||||
| symmetric=kwargs["symmetric"], | |||||
| narrow_range=kwargs["narrow_range"]) | |||||
| def construct(self, *data): | def construct(self, *data): | ||||
| output = self.subcell(*data) | output = self.subcell(*data) | ||||
| @@ -82,30 +86,20 @@ class ConvertToQuantNetwork: | |||||
| """ | """ | ||||
| __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] | __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] | ||||
| def __init__(self, | |||||
| network, | |||||
| quant_delay=0, | |||||
| bn_fold=False, | |||||
| freeze_bn=0, | |||||
| weight_bits=8, | |||||
| act_bits=8, | |||||
| per_channel=False, | |||||
| symmetric=False, | |||||
| narrow_range=False): | |||||
| self.network = validator.check_isinstance( | |||||
| 'network', network, (nn.Cell,)) | |||||
| self.quant_delay = validator.check_integer( | |||||
| "quant delay", quant_delay, 0, Rel.GE) | |||||
| self.freeze_bn = validator.check_integer( | |||||
| "freeze bn", freeze_bn, 0, Rel.GE) | |||||
| self.weight_bits = validator.check_integer( | |||||
| "weights bit", weight_bits, 0, Rel.GE) | |||||
| self.act_bits = validator.check_integer( | |||||
| "activations bit", act_bits, 0, Rel.GE) | |||||
| self.bn_fold = validator.check_bool("bn fold", bn_fold) | |||||
| self.per_channel = validator.check_bool("per channel", per_channel) | |||||
| self.symmetric = validator.check_bool("symmetric", symmetric) | |||||
| self.narrow_range = validator.check_bool("narrow range", narrow_range) | |||||
| def __init__(self, **kwargs): | |||||
| self.network = validator.check_isinstance('network', kwargs["network"], (nn.Cell,)) | |||||
| self.weight_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE) | |||||
| self.act_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE) | |||||
| self.bn_fold = validator.check_bool("bn fold", kwargs["bn_fold"]) | |||||
| self.freeze_bn = validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE) | |||||
| self.weight_bits = validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE) | |||||
| self.act_bits = validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE) | |||||
| self.weight_channel = validator.check_bool("per channel", kwargs["per_channel"][0]) | |||||
| self.act_channel = validator.check_bool("per channel", kwargs["per_channel"][-1]) | |||||
| self.weight_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][0]) | |||||
| self.act_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][-1]) | |||||
| self.weight_range = validator.check_bool("narrow range", kwargs["narrow_range"][0]) | |||||
| self.act_range = validator.check_bool("narrow range", kwargs["narrow_range"][-1]) | |||||
| self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv, | self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv, | ||||
| quant.DenseBnAct: self._convert_dense} | quant.DenseBnAct: self._convert_dense} | ||||
| @@ -153,7 +147,12 @@ class ConvertToQuantNetwork: | |||||
| add_list.append((name, attr)) | add_list.append((name, attr)) | ||||
| for name, prim_op in add_list: | for name, prim_op in add_list: | ||||
| prefix = name | prefix = name | ||||
| add_quant = _AddFakeQuantAfterSubCell(prim_op) # quant.TensorAddQuant() | |||||
| add_quant = _AddFakeQuantAfterSubCell(prim_op, | |||||
| num_bits=self.act_bits, | |||||
| quant_delay=self.act_qdelay, | |||||
| per_channel=self.act_channel, | |||||
| symmetric=self.act_symmetric, | |||||
| narrow_range=self.act_range) | |||||
| prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)]) | prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)]) | ||||
| add_quant.update_parameters_name(prefix + '.') | add_quant.update_parameters_name(prefix + '.') | ||||
| del network.__dict__[name] | del network.__dict__[name] | ||||
| @@ -177,13 +176,13 @@ class ConvertToQuantNetwork: | |||||
| group=conv_inner.group, | group=conv_inner.group, | ||||
| eps=bn_inner.eps, | eps=bn_inner.eps, | ||||
| momentum=bn_inner.momentum, | momentum=bn_inner.momentum, | ||||
| quant_delay=self.quant_delay, | |||||
| quant_delay=self.weight_qdelay, | |||||
| freeze_bn=self.freeze_bn, | freeze_bn=self.freeze_bn, | ||||
| per_channel=self.per_channel, | |||||
| per_channel=self.weight_channel, | |||||
| num_bits=self.weight_bits, | num_bits=self.weight_bits, | ||||
| fake=True, | fake=True, | ||||
| symmetric=self.symmetric, | |||||
| narrow_range=self.narrow_range) | |||||
| symmetric=self.weight_symmetric, | |||||
| narrow_range=self.weight_range) | |||||
| del subcell.batchnorm | del subcell.batchnorm | ||||
| subcell.batchnorm = None | subcell.batchnorm = None | ||||
| subcell.has_bn = False | subcell.has_bn = False | ||||
| @@ -197,18 +196,22 @@ class ConvertToQuantNetwork: | |||||
| dilation=conv_inner.dilation, | dilation=conv_inner.dilation, | ||||
| group=conv_inner.group, | group=conv_inner.group, | ||||
| has_bias=conv_inner.has_bias, | has_bias=conv_inner.has_bias, | ||||
| quant_delay=self.quant_delay, | |||||
| per_channel=self.per_channel, | |||||
| quant_delay=self.weight_qdelay, | |||||
| per_channel=self.weight_channel, | |||||
| num_bits=self.weight_bits, | num_bits=self.weight_bits, | ||||
| symmetric=self.symmetric, | |||||
| narrow_range=self.narrow_range) | |||||
| symmetric=self.weight_symmetric, | |||||
| narrow_range=self.weight_range) | |||||
| subcell.conv = conv_inner | subcell.conv = conv_inner | ||||
| if subcell.has_act and subcell.activation is not None: | if subcell.has_act and subcell.activation is not None: | ||||
| subcell.activation = self._convert_activation(subcell.activation) | subcell.activation = self._convert_activation(subcell.activation) | ||||
| else: | else: | ||||
| subcell.has_act = True | subcell.has_act = True | ||||
| subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits, | |||||
| quant_delay=self.quant_delay) | |||||
| subcell.activation = _AddFakeQuantAfterSubCell(F.identity, | |||||
| num_bits=self.act_bits, | |||||
| quant_delay=self.act_qdelay, | |||||
| per_channel=self.act_channel, | |||||
| symmetric=self.act_symmetric, | |||||
| narrow_range=self.act_range) | |||||
| return subcell | return subcell | ||||
| def _convert_dense(self, subcell): | def _convert_dense(self, subcell): | ||||
| @@ -219,16 +222,22 @@ class ConvertToQuantNetwork: | |||||
| dense_inner = quant.DenseQuant(dense_inner.in_channels, | dense_inner = quant.DenseQuant(dense_inner.in_channels, | ||||
| dense_inner.out_channels, | dense_inner.out_channels, | ||||
| has_bias=dense_inner.has_bias, | has_bias=dense_inner.has_bias, | ||||
| quant_delay=self.quant_delay, | |||||
| per_channel=self.per_channel, | |||||
| num_bits=self.weight_bits) | |||||
| num_bits=self.weight_bits, | |||||
| quant_delay=self.weight_qdelay, | |||||
| per_channel=self.weight_channel, | |||||
| symmetric=self.weight_symmetric, | |||||
| narrow_range=self.weight_range) | |||||
| subcell.dense = dense_inner | subcell.dense = dense_inner | ||||
| if subcell.has_act and subcell.activation is not None: | if subcell.has_act and subcell.activation is not None: | ||||
| subcell.activation = self._convert_activation(subcell.activation) | subcell.activation = self._convert_activation(subcell.activation) | ||||
| else: | else: | ||||
| subcell.has_act = True | subcell.has_act = True | ||||
| subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits, | |||||
| quant_delay=self.quant_delay) | |||||
| subcell.activation = _AddFakeQuantAfterSubCell(F.identity, | |||||
| num_bits=self.act_bits, | |||||
| quant_delay=self.act_qdelay, | |||||
| per_channel=self.act_channel, | |||||
| symmetric=self.act_symmetric, | |||||
| narrow_range=self.act_range) | |||||
| return subcell | return subcell | ||||
| def _convert_activation(self, activation): | def _convert_activation(self, activation): | ||||
| @@ -236,7 +245,11 @@ class ConvertToQuantNetwork: | |||||
| if act_class not in _ACTIVATION_MAP: | if act_class not in _ACTIVATION_MAP: | ||||
| raise ValueError( | raise ValueError( | ||||
| "Unsupported activation in auto Quant: ", act_class) | "Unsupported activation in auto Quant: ", act_class) | ||||
| return _ACTIVATION_MAP[act_class](num_bits=self.act_bits, quant_delay=self.quant_delay) | |||||
| return _ACTIVATION_MAP[act_class](num_bits=self.act_bits, | |||||
| quant_delay=self.act_qdelay, | |||||
| per_channel=self.act_channel, | |||||
| symmetric=self.act_symmetric, | |||||
| narrow_range=self.act_range) | |||||
| class ExportQuantNetworkDeploy: | class ExportQuantNetworkDeploy: | ||||
| @@ -381,32 +394,57 @@ def export_geir(network, *inputs, file_name): | |||||
| def convert_quant_network(network, | def convert_quant_network(network, | ||||
| quant_delay=0, | |||||
| bn_fold=False, | bn_fold=False, | ||||
| freeze_bn=0, | freeze_bn=0, | ||||
| weight_bits=8, | |||||
| act_bits=8, | |||||
| per_channel=False, | |||||
| symmetric=False, | |||||
| narrow_range=False | |||||
| quant_delay=(0, 0), | |||||
| num_bits=(8, 8), | |||||
| per_channel=(False, False), | |||||
| symmetric=(False, False), | |||||
| narrow_range=(False, False) | |||||
| ): | ): | ||||
| r""" | r""" | ||||
| Create aware quantization training network. | Create aware quantization training network. | ||||
| Args: | Args: | ||||
| network (Cell): Obtain a pipeline through network for saving graph summary. | network (Cell): Obtain a pipeline through network for saving graph summary. | ||||
| quant_delay (int): Number of steps after which weights and activations are quantized during eval. Default: 0. | |||||
| quant_delay (list of int): Number of steps after which weights and activations are quantized during | |||||
| eval. The first element represents weights and the second element represents data flow. Default: [0, 0]. | |||||
| bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False. | bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False. | ||||
| freeze_bn (int): Number of steps after which BN parameters used total mean and variance. Default: 0. | |||||
| weight_bits (int): Number of bits to use for quantizing weights. Default: 8. | |||||
| act_bits (int): Number of bits to use for quantizing activations. Default: 8. | |||||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||||
| freeze_bn (int): Number of steps after which the BatchNorm OP parameters use the total (running) mean and variance. Default: 0. | |||||
| num_bits (list of int): Number of bits to use for quantizing weights and activations. The first | |||||
| element represents weights and the second element represents the data flow. Default: (8, 8) | |||||
| per_channel (list of bool): Quantization granularity, per layer or per channel. If `True`, | |||||
| quantization is applied per channel, otherwise per layer. The first element represents weights | |||||
| and the second element represents the data flow. Default: (False, False) | |||||
| symmetric (list of bool): Whether the quantization algorithm is symmetric. If `True`, symmetric | |||||
| quantization is used, otherwise asymmetric. The first element represents weights and the second | |||||
| element represents the data flow. Default: (False, False) | |||||
| narrow_range (list of bool): Whether the quantization algorithm uses the narrow range. If `True`, | |||||
| the narrow range is used, otherwise the full range. The first element represents weights and | |||||
| the second element represents the data flow. Default: (False, False) | |||||
| Returns: | Returns: | ||||
| Cell, Network which has change to aware quantization training network. | |||||
| Cell, a network cell that has been converted to a quantization-aware training network. | |||||
| """ | """ | ||||
| net = ConvertToQuantNetwork( | |||||
| network, quant_delay, bn_fold, freeze_bn, weight_bits, act_bits, per_channel, symmetric, narrow_range) | |||||
| def convert2list(name, value): | |||||
| if not isinstance(value, list) and not isinstance(value, tuple): | |||||
| value = [value] | |||||
| elif len(value) > 2: | |||||
| raise ValueError("input `{}` len should less then 2".format(name)) | |||||
| return value | |||||
| quant_delay = convert2list("quant delay", quant_delay) | |||||
| num_bits = convert2list("num bits", num_bits) | |||||
| per_channel = convert2list("per channel", per_channel) | |||||
| symmetric = convert2list("symmetric", symmetric) | |||||
| narrow_range = convert2list("narrow range", narrow_range) | |||||
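# Normalization sketch (illustrative): convert2list leaves lists/tuples untouched and
# wraps a scalar in a one-element list, so callers may pass either form, e.g.
#   convert2list("num bits", 8)      -> [8]      (same value for weights and data flow)
#   convert2list("num bits", (4, 8)) -> (4, 8)   (weights read index 0, data flow reads index -1)
#   convert2list("num bits", (4, 8, 16)) raises the ValueError above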
| net = ConvertToQuantNetwork(network=network, | |||||
| quant_delay=quant_delay, | |||||
| bn_fold=bn_fold, | |||||
| freeze_bn=freeze_bn, | |||||
| num_bits=num_bits, | |||||
| per_channel=per_channel, | |||||
| symmetric=symmetric, | |||||
| narrow_range=narrow_range) | |||||
| return net.run() | return net.run() | ||||
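A minimal usage sketch, mirroring the test snippets further below; the qat import path and the LeNet5 network are assumptions, and a scalar argument is broadcast to a (weights, data flow) pair by convert2list above:

    import numpy as np
    from mindspore import Tensor
    from mindspore.train.quant import quant as qat   # import path assumed

    net = LeNet5()                                    # any float network cell (assumed to be defined)
    net = qat.convert_quant_network(net, quant_delay=0, bn_fold=False,
                                    freeze_bn=10000, num_bits=8)
    # the converted cell is used exactly like the original float network
    out = net(Tensor(np.ones((32, 1, 32, 32)).astype(np.float32)))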
| @@ -54,4 +54,4 @@ if __name__ == "__main__": | |||||
| cfg.batch_size, | cfg.batch_size, | ||||
| status="test") | status="test") | ||||
| acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) | acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) | ||||
| print("============== Accuracy:{} ==============".format(acc)) | |||||
| print("============== {} ==============".format(acc)) | |||||
| @@ -61,4 +61,4 @@ if __name__ == "__main__": | |||||
| cfg.batch_size, | cfg.batch_size, | ||||
| 1) | 1) | ||||
| acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) | acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) | ||||
| print("============== Accuracy:{} ==============".format(acc)) | |||||
| print("============== {} ==============".format(acc)) | |||||
| @@ -78,4 +78,4 @@ if __name__ == '__main__': | |||||
| acc = model.eval(ds_eval, dataset_sink_mode=False) | acc = model.eval(ds_eval, dataset_sink_mode=False) | ||||
| else: | else: | ||||
| acc = model.eval(ds_eval) | acc = model.eval(ds_eval) | ||||
| print("============== Accuracy:{} ==============".format(acc)) | |||||
| print("============== {} ==============".format(acc)) | |||||
| @@ -203,4 +203,4 @@ def test_train_and_eval_lenet(): | |||||
| print("============== Starting Testing ==============") | print("============== Starting Testing ==============") | ||||
| ds_eval = create_dataset(os.path.join('/home/workspace/mindspore_dataset/mnist', "test"), 32, 1) | ds_eval = create_dataset(os.path.join('/home/workspace/mindspore_dataset/mnist', "test"), 32, 1) | ||||
| acc = model.eval(ds_eval, dataset_sink_mode=True) | acc = model.eval(ds_eval, dataset_sink_mode=True) | ||||
| print("============== Accuracy:{} ==============".format(acc)) | |||||
| print("============== {} ==============".format(acc)) | |||||
| @@ -67,7 +67,7 @@ def test_qat_lenet(): | |||||
| img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32)) | img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32)) | ||||
| net = LeNet5() | net = LeNet5() | ||||
| net = qat.convert_quant_network( | net = qat.convert_quant_network( | ||||
| net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8) | |||||
| net, quant_delay=0, bn_fold=False, freeze_bn=10000, num_bits=8) | |||||
| # should load the checkpoint. mock here | # should load the checkpoint. mock here | ||||
| for param in net.get_parameters(): | for param in net.get_parameters(): | ||||
| param.init_data() | param.init_data() | ||||
| @@ -79,7 +79,7 @@ def test_qat_mobile(): | |||||
| net = MobileNetV2() | net = MobileNetV2() | ||||
| img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) | img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) | ||||
| net = qat.convert_quant_network( | net = qat.convert_quant_network( | ||||
| net, quant_delay=0, bn_fold=True, freeze_bn=10000, weight_bits=8, act_bits=8) | |||||
| net, quant_delay=0, bn_fold=True, freeze_bn=10000, num_bits=8) | |||||
| # should load the checkpoint. mock here | # should load the checkpoint. mock here | ||||
| for param in net.get_parameters(): | for param in net.get_parameters(): | ||||
| param.init_data() | param.init_data() | ||||
| @@ -117,6 +117,7 @@ def test_loss_monitor_sink_mode(): | |||||
| """Test loss monitor sink mode.""" | """Test loss monitor sink mode.""" | ||||
| cb_params = _InternalCallbackParam() | cb_params = _InternalCallbackParam() | ||||
| cb_params.cur_epoch_num = 4 | cb_params.cur_epoch_num = 4 | ||||
| cb_params.epoch_num = 4 | |||||
| cb_params.cur_step_num = 2 | cb_params.cur_step_num = 2 | ||||
| cb_params.batch_num = 2 | cb_params.batch_num = 2 | ||||
| cb_params.net_outputs = Tensor(2.0) | cb_params.net_outputs = Tensor(2.0) | ||||
| @@ -138,6 +139,7 @@ def test_loss_monitor_normal_mode(): | |||||
| run_context = RunContext(cb_params) | run_context = RunContext(cb_params) | ||||
| loss_cb = LossMonitor(1) | loss_cb = LossMonitor(1) | ||||
| cb_params.cur_epoch_num = 4 | cb_params.cur_epoch_num = 4 | ||||
| cb_params.epoch_num = 4 | |||||
| cb_params.cur_step_num = 1 | cb_params.cur_step_num = 1 | ||||
| cb_params.batch_num = 1 | cb_params.batch_num = 1 | ||||
| cb_params.net_outputs = Tensor(2.0) | cb_params.net_outputs = Tensor(2.0) | ||||