diff --git a/mindspore/model_zoo/mobilenetV2.py b/mindspore/model_zoo/mobilenetV2.py deleted file mode 100644 index 5b1b4cc5ef..0000000000 --- a/mindspore/model_zoo/mobilenetV2.py +++ /dev/null @@ -1,291 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""MobileNetV2 model define""" -import numpy as np -import mindspore.nn as nn -from mindspore.ops import operations as P -from mindspore.ops.operations import TensorAdd -from mindspore import Parameter, Tensor -from mindspore.common.initializer import initializer - -__all__ = ['mobilenet_v2'] - - -def _make_divisible(v, divisor, min_value=None): - if min_value is None: - min_value = divisor - new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) - # Make sure that round down does not go down by more than 10%. - if new_v < 0.9 * v: - new_v += divisor - return new_v - - -class GlobalAvgPooling(nn.Cell): - """ - Global avg pooling definition. - - Args: - - Returns: - Tensor, output tensor. - - Examples: - >>> GlobalAvgPooling() - """ - - def __init__(self): - super(GlobalAvgPooling, self).__init__() - self.mean = P.ReduceMean(keep_dims=False) - - def construct(self, x): - x = self.mean(x, (2, 3)) - return x - - -class DepthwiseConv(nn.Cell): - """ - Depthwise Convolution warpper definition. - - Args: - in_planes (int): Input channel. - kernel_size (int): Input kernel size. - stride (int): Stride size. - pad_mode (str): pad mode in (pad, same, valid) - channel_multiplier (int): Output channel multiplier - has_bias (bool): has bias or not - - Returns: - Tensor, output tensor. - - Examples: - >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1) - """ - - def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False): - super(DepthwiseConv, self).__init__() - self.has_bias = has_bias - self.in_channels = in_planes - self.channel_multiplier = channel_multiplier - self.out_channels = in_planes * channel_multiplier - self.kernel_size = (kernel_size, kernel_size) - self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, - kernel_size=self.kernel_size, - stride=stride, pad_mode=pad_mode, pad=pad) - self.bias_add = P.BiasAdd() - weight_shape = [channel_multiplier, in_planes, *self.kernel_size] - self.weight = Parameter(initializer('ones', weight_shape), name='weight') - - if has_bias: - bias_shape = [channel_multiplier * in_planes] - self.bias = Parameter(initializer('zeros', bias_shape), name='bias') - else: - self.bias = None - - def construct(self, x): - output = self.depthwise_conv(x, self.weight) - if self.has_bias: - output = self.bias_add(output, self.bias) - return output - - -class ConvBNReLU(nn.Cell): - """ - Convolution/Depthwise fused with Batchnorm and ReLU block definition. - - Args: - in_planes (int): Input channel. - out_planes (int): Output channel. - kernel_size (int): Input kernel size. 
- stride (int): Stride size for the first convolutional layer. Default: 1. - groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. - - Returns: - Tensor, output tensor. - - Examples: - >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) - """ - - def __init__(self, platform, in_planes, out_planes, kernel_size=3, stride=1, groups=1): - super(ConvBNReLU, self).__init__() - padding = (kernel_size - 1) // 2 - if groups == 1: - conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding) - else: - if platform == "Ascend": - conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding) - elif platform == "GPU": - conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, - group=in_planes, pad_mode='pad', padding=padding) - - layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()] - self.features = nn.SequentialCell(layers) - - def construct(self, x): - output = self.features(x) - return output - - -class InvertedResidual(nn.Cell): - """ - Mobilenetv2 residual block definition. - - Args: - inp (int): Input channel. - oup (int): Output channel. - stride (int): Stride size for the first convolutional layer. Default: 1. - expand_ratio (int): expand ration of input channel - - Returns: - Tensor, output tensor. - - Examples: - >>> ResidualBlock(3, 256, 1, 1) - """ - - def __init__(self, platform, inp, oup, stride, expand_ratio): - super(InvertedResidual, self).__init__() - assert stride in [1, 2] - - hidden_dim = int(round(inp * expand_ratio)) - self.use_res_connect = stride == 1 and inp == oup - - layers = [] - if expand_ratio != 1: - layers.append(ConvBNReLU(platform, inp, hidden_dim, kernel_size=1)) - layers.extend([ - # dw - ConvBNReLU(platform, hidden_dim, hidden_dim, - stride=stride, groups=hidden_dim), - # pw-linear - nn.Conv2d(hidden_dim, oup, kernel_size=1, - stride=1, has_bias=False), - nn.BatchNorm2d(oup), - ]) - self.conv = nn.SequentialCell(layers) - self.add = TensorAdd() - self.cast = P.Cast() - - def construct(self, x): - identity = x - x = self.conv(x) - if self.use_res_connect: - return self.add(identity, x) - return x - - -class MobileNetV2(nn.Cell): - """ - MobileNetV2 architecture. - - Args: - class_num (Cell): number of classes. - width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1. - has_dropout (bool): Is dropout used. Default is false - inverted_residual_setting (list): Inverted residual settings. Default is None - round_nearest (list): Channel round to . Default is 8 - Returns: - Tensor, output tensor. 
- - Examples: - >>> MobileNetV2(num_classes=1000) - """ - - def __init__(self, platform, num_classes=1000, width_mult=1., - has_dropout=False, inverted_residual_setting=None, round_nearest=8): - super(MobileNetV2, self).__init__() - block = InvertedResidual - input_channel = 32 - last_channel = 1280 - # setting of inverted residual blocks - self.cfgs = inverted_residual_setting - if inverted_residual_setting is None: - self.cfgs = [ - # t, c, n, s - [1, 16, 1, 1], - [6, 24, 2, 2], - [6, 32, 3, 2], - [6, 64, 4, 2], - [6, 96, 3, 1], - [6, 160, 3, 2], - [6, 320, 1, 1], - ] - - # building first layer - input_channel = _make_divisible(input_channel * width_mult, round_nearest) - self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) - features = [ConvBNReLU(platform, 3, input_channel, stride=2)] - # building inverted residual blocks - for t, c, n, s in self.cfgs: - output_channel = _make_divisible(c * width_mult, round_nearest) - for i in range(n): - stride = s if i == 0 else 1 - features.append(block(platform, input_channel, output_channel, stride, expand_ratio=t)) - input_channel = output_channel - # building last several layers - features.append(ConvBNReLU(platform, input_channel, self.out_channels, kernel_size=1)) - # make it nn.CellList - self.features = nn.SequentialCell(features) - # mobilenet head - head = ([GlobalAvgPooling(), nn.Dense(self.out_channels, num_classes, has_bias=True)] if not has_dropout else - [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(self.out_channels, num_classes, has_bias=True)]) - self.head = nn.SequentialCell(head) - - self._initialize_weights() - - def construct(self, x): - x = self.features(x) - x = self.head(x) - return x - - def _initialize_weights(self): - """ - Initialize weights. - - Args: - - Returns: - None. - - Examples: - >>> _initialize_weights() - """ - for _, m in self.cells_and_names(): - if isinstance(m, (nn.Conv2d, DepthwiseConv)): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), - m.weight.data.shape).astype("float32"))) - if m.bias is not None: - m.bias.set_parameter_data( - Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) - elif isinstance(m, nn.BatchNorm2d): - m.gamma.set_parameter_data( - Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) - m.beta.set_parameter_data( - Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) - elif isinstance(m, nn.Dense): - m.weight.set_parameter_data(Tensor(np.random.normal( - 0, 0.01, m.weight.data.shape).astype("float32"))) - if m.bias is not None: - m.bias.set_parameter_data( - Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) - - -def mobilenet_v2(**kwargs): - """ - Constructs a MobileNet V2 model - """ - return MobileNetV2(**kwargs) diff --git a/mindspore/model_zoo/mobilenetV3.py b/mindspore/model_zoo/mobilenetV3.py deleted file mode 100644 index 61b63f9ea1..0000000000 --- a/mindspore/model_zoo/mobilenetV3.py +++ /dev/null @@ -1,390 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""MobileNetV3 model define""" -from functools import partial -import numpy as np -import mindspore.nn as nn -from mindspore.ops import operations as P -from mindspore import Tensor - - -__all__ = ['mobilenet_v3_large', - 'mobilenet_v3_small'] - - -def _make_divisible(x, divisor=8): - return int(np.ceil(x * 1. / divisor) * divisor) - - -class Activation(nn.Cell): - """ - Activation definition. - - Args: - act_func(string): activation name. - - Returns: - Tensor, output tensor. - """ - - def __init__(self, act_func): - super(Activation, self).__init__() - if act_func == 'relu': - self.act = nn.ReLU() - elif act_func == 'relu6': - self.act = nn.ReLU6() - elif act_func in ('hsigmoid', 'hard_sigmoid'): - self.act = nn.HSigmoid() - elif act_func in ('hswish', 'hard_swish'): - self.act = nn.HSwish() - else: - raise NotImplementedError - - def construct(self, x): - return self.act(x) - - -class GlobalAvgPooling(nn.Cell): - """ - Global avg pooling definition. - - Args: - - Returns: - Tensor, output tensor. - - Examples: - >>> GlobalAvgPooling() - """ - - def __init__(self, keep_dims=False): - super(GlobalAvgPooling, self).__init__() - self.mean = P.ReduceMean(keep_dims=keep_dims) - - def construct(self, x): - x = self.mean(x, (2, 3)) - return x - - -class SE(nn.Cell): - """ - SE warpper definition. - - Args: - num_out (int): Output channel. - ratio (int): middle output ratio. - - Returns: - Tensor, output tensor. - - Examples: - >>> SE(4) - """ - - def __init__(self, num_out, ratio=4): - super(SE, self).__init__() - num_mid = _make_divisible(num_out // ratio) - self.pool = GlobalAvgPooling(keep_dims=True) - self.conv1 = nn.Conv2d(in_channels=num_out, out_channels=num_mid, - kernel_size=1, has_bias=True, pad_mode='pad') - self.act1 = Activation('relu') - self.conv2 = nn.Conv2d(in_channels=num_mid, out_channels=num_out, - kernel_size=1, has_bias=True, pad_mode='pad') - self.act2 = Activation('hsigmoid') - self.mul = P.Mul() - - def construct(self, x): - out = self.pool(x) - out = self.conv1(out) - out = self.act1(out) - out = self.conv2(out) - out = self.act2(out) - out = self.mul(x, out) - return out - - -class Unit(nn.Cell): - """ - Unit warpper definition. - - Args: - num_in (int): Input channel. - num_out (int): Output channel. - kernel_size (int): Input kernel size. - stride (int): Stride size. - padding (int): Padding number. - num_groups (int): Output num group. - use_act (bool): Used activation or not. - act_type (string): Activation type. - - Returns: - Tensor, output tensor. - - Examples: - >>> Unit(3, 3) - """ - - def __init__(self, num_in, num_out, kernel_size=1, stride=1, padding=0, num_groups=1, - use_act=True, act_type='relu'): - super(Unit, self).__init__() - self.conv = nn.Conv2d(in_channels=num_in, - out_channels=num_out, - kernel_size=kernel_size, - stride=stride, - padding=padding, - group=num_groups, - has_bias=False, - pad_mode='pad') - self.bn = nn.BatchNorm2d(num_out) - self.use_act = use_act - self.act = Activation(act_type) if use_act else None - - def construct(self, x): - out = self.conv(x) - out = self.bn(out) - if self.use_act: - out = self.act(out) - return out - - -class ResUnit(nn.Cell): - """ - ResUnit warpper definition. - - Args: - num_in (int): Input channel. - num_mid (int): Middle channel. - num_out (int): Output channel. - kernel_size (int): Input kernel size. 
- stride (int): Stride size. - act_type (str): Activation type. - use_se (bool): Use SE warpper or not. - - Returns: - Tensor, output tensor. - - Examples: - >>> ResUnit(16, 3, 1, 1) - """ - def __init__(self, num_in, num_mid, num_out, kernel_size, stride=1, act_type='relu', use_se=False): - super(ResUnit, self).__init__() - self.use_se = use_se - self.first_conv = (num_out != num_mid) - self.use_short_cut_conv = True - - if self.first_conv: - self.expand = Unit(num_in, num_mid, kernel_size=1, - stride=1, padding=0, act_type=act_type) - else: - self.expand = None - self.conv1 = Unit(num_mid, num_mid, kernel_size=kernel_size, stride=stride, - padding=self._get_pad(kernel_size), act_type=act_type, num_groups=num_mid) - if use_se: - self.se = SE(num_mid) - self.conv2 = Unit(num_mid, num_out, kernel_size=1, stride=1, - padding=0, act_type=act_type, use_act=False) - if num_in != num_out or stride != 1: - self.use_short_cut_conv = False - self.add = P.TensorAdd() if self.use_short_cut_conv else None - - def construct(self, x): - if self.first_conv: - out = self.expand(x) - else: - out = x - out = self.conv1(out) - if self.use_se: - out = self.se(out) - out = self.conv2(out) - if self.use_short_cut_conv: - out = self.add(x, out) - return out - - def _get_pad(self, kernel_size): - """set the padding number""" - pad = 0 - if kernel_size == 1: - pad = 0 - elif kernel_size == 3: - pad = 1 - elif kernel_size == 5: - pad = 2 - elif kernel_size == 7: - pad = 3 - else: - raise NotImplementedError - return pad - - -class MobileNetV3(nn.Cell): - """ - MobileNetV3 architecture. - - Args: - model_cfgs (Cell): number of classes. - num_classes (int): Output number classes. - multiplier (int): Channels multiplier for round to 8/16 and others. Default is 1. - final_drop (float): Dropout number. - round_nearest (list): Channel round to . Default is 8. - Returns: - Tensor, output tensor. 
- - Examples: - >>> MobileNetV3(num_classes=1000) - """ - - def __init__(self, model_cfgs, num_classes=1000, multiplier=1., final_drop=0., round_nearest=8): - super(MobileNetV3, self).__init__() - self.cfgs = model_cfgs['cfg'] - self.inplanes = 16 - self.features = [] - first_conv_in_channel = 3 - first_conv_out_channel = _make_divisible(multiplier * self.inplanes) - - self.features.append(nn.Conv2d(in_channels=first_conv_in_channel, - out_channels=first_conv_out_channel, - kernel_size=3, padding=1, stride=2, - has_bias=False, pad_mode='pad')) - self.features.append(nn.BatchNorm2d(first_conv_out_channel)) - self.features.append(Activation('hswish')) - for layer_cfg in self.cfgs: - self.features.append(self._make_layer(kernel_size=layer_cfg[0], - exp_ch=_make_divisible(multiplier * layer_cfg[1]), - out_channel=_make_divisible(multiplier * layer_cfg[2]), - use_se=layer_cfg[3], - act_func=layer_cfg[4], - stride=layer_cfg[5])) - output_channel = _make_divisible(multiplier * model_cfgs["cls_ch_squeeze"]) - self.features.append(nn.Conv2d(in_channels=_make_divisible(multiplier * self.cfgs[-1][2]), - out_channels=output_channel, - kernel_size=1, padding=0, stride=1, - has_bias=False, pad_mode='pad')) - self.features.append(nn.BatchNorm2d(output_channel)) - self.features.append(Activation('hswish')) - self.features.append(GlobalAvgPooling(keep_dims=True)) - self.features.append(nn.Conv2d(in_channels=output_channel, - out_channels=model_cfgs['cls_ch_expand'], - kernel_size=1, padding=0, stride=1, - has_bias=False, pad_mode='pad')) - self.features.append(Activation('hswish')) - if final_drop > 0: - self.features.append((nn.Dropout(final_drop))) - - # make it nn.CellList - self.features = nn.SequentialCell(self.features) - self.output = nn.Conv2d(in_channels=model_cfgs['cls_ch_expand'], - out_channels=num_classes, - kernel_size=1, has_bias=True, pad_mode='pad') - self.squeeze = P.Squeeze(axis=(2, 3)) - - self._initialize_weights() - - def construct(self, x): - x = self.features(x) - x = self.output(x) - x = self.squeeze(x) - return x - - def _make_layer(self, kernel_size, exp_ch, out_channel, use_se, act_func, stride=1): - mid_planes = exp_ch - out_planes = out_channel - #num_in, num_mid, num_out, kernel_size, stride=1, act_type='relu', use_se=False): - layer = ResUnit(self.inplanes, mid_planes, out_planes, - kernel_size, stride=stride, act_type=act_func, use_se=use_se) - self.inplanes = out_planes - return layer - - def _initialize_weights(self): - """ - Initialize weights. - - Args: - - Returns: - None. - - Examples: - >>> _initialize_weights() - """ - for _, m in self.cells_and_names(): - if isinstance(m, (nn.Conv2d)): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. 
/ n), - m.weight.data.shape).astype("float32"))) - if m.bias is not None: - m.bias.set_parameter_data( - Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) - elif isinstance(m, nn.BatchNorm2d): - m.gamma.set_parameter_data( - Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) - m.beta.set_parameter_data( - Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) - elif isinstance(m, nn.Dense): - m.weight.set_parameter_data(Tensor(np.random.normal( - 0, 0.01, m.weight.data.shape).astype("float32"))) - if m.bias is not None: - m.bias.set_parameter_data( - Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) - - -def mobilenet_v3(model_name, **kwargs): - """ - Constructs a MobileNet V2 model - """ - model_cfgs = { - "large": { - "cfg": [ - # k, exp, c, se, nl, s, - [3, 16, 16, False, 'relu', 1], - [3, 64, 24, False, 'relu', 2], - [3, 72, 24, False, 'relu', 1], - [5, 72, 40, True, 'relu', 2], - [5, 120, 40, True, 'relu', 1], - [5, 120, 40, True, 'relu', 1], - [3, 240, 80, False, 'hswish', 2], - [3, 200, 80, False, 'hswish', 1], - [3, 184, 80, False, 'hswish', 1], - [3, 184, 80, False, 'hswish', 1], - [3, 480, 112, True, 'hswish', 1], - [3, 672, 112, True, 'hswish', 1], - [5, 672, 160, True, 'hswish', 2], - [5, 960, 160, True, 'hswish', 1], - [5, 960, 160, True, 'hswish', 1]], - "cls_ch_squeeze": 960, - "cls_ch_expand": 1280, - }, - "small": { - "cfg": [ - # k, exp, c, se, nl, s, - [3, 16, 16, True, 'relu', 2], - [3, 72, 24, False, 'relu', 2], - [3, 88, 24, False, 'relu', 1], - [5, 96, 40, True, 'hswish', 2], - [5, 240, 40, True, 'hswish', 1], - [5, 240, 40, True, 'hswish', 1], - [5, 120, 48, True, 'hswish', 1], - [5, 144, 48, True, 'hswish', 1], - [5, 288, 96, True, 'hswish', 2], - [5, 576, 96, True, 'hswish', 1], - [5, 576, 96, True, 'hswish', 1]], - "cls_ch_squeeze": 576, - "cls_ch_expand": 1280, - } - } - return MobileNetV3(model_cfgs[model_name], **kwargs) - - -mobilenet_v3_large = partial(mobilenet_v3, model_name="large") -mobilenet_v3_small = partial(mobilenet_v3, model_name="small") diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index 0ab1fe24df..6c01aa5404 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -279,7 +279,7 @@ class FakeQuantWithMinMax(Cell): num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. ema (bool): Exponential Moving Average algorithm update min and max. Default: False. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. - per_channel (bool): Quantization by layer or channel. Default: False. + per_channel (bool): Quantization granularity based on layer or on channel. Default: False. channel_axis (int): Quantization by channel axis. Default: 1. out_channels (int): declarate the min and max channel size, Default: 1. quant_delay (int): Quantization delay parameters according by global step. Default: 0. @@ -407,7 +407,7 @@ class Conv2dBatchNormQuant(Cell): freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000. fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. + per_channel (bool): Quantization granularity based on layer or on channel. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. 
@@ -584,7 +584,7 @@ class Conv2dQuant(Cell): bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'. quant_delay (int): Quantization delay parameters according by global step. Default: 0. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. + per_channel (bool): Quantization granularity based on layer or on channel. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. @@ -694,7 +694,7 @@ class DenseQuant(Cell): activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. quant_delay (int): Quantization delay parameters according by global step. Default: 0. - per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. + per_channel (bool): Quantization granularity based on layer or on channel. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. @@ -797,6 +797,7 @@ class ReLUQuant(_QuantActivation): num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. + per_channel (bool): Quantization granularity based on layer or on channel. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. @@ -816,6 +817,7 @@ class ReLUQuant(_QuantActivation): num_bits=8, quant_delay=0, ema_decay=0.999, + per_channel=False, symmetric=False, narrow_range=False): super(ReLUQuant, self).__init__() @@ -824,6 +826,7 @@ class ReLUQuant(_QuantActivation): num_bits=num_bits, quant_delay=quant_delay, ema=True, + per_channel=per_channel, ema_decay=ema_decay, symmetric=symmetric, narrow_range=narrow_range) @@ -850,6 +853,7 @@ class ReLU6Quant(_QuantActivation): num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. + per_channel (bool): Quantization granularity based on layer or on channel. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. @@ -869,6 +873,7 @@ class ReLU6Quant(_QuantActivation): num_bits=8, quant_delay=0, ema_decay=0.999, + per_channel=False, symmetric=False, narrow_range=False): super(ReLU6Quant, self).__init__() @@ -877,6 +882,7 @@ class ReLU6Quant(_QuantActivation): num_bits=num_bits, quant_delay=quant_delay, ema=True, + per_channel=per_channel, ema_decay=ema_decay, symmetric=symmetric, narrow_range=narrow_range) @@ -900,6 +906,7 @@ class HSwishQuant(_QuantActivation): num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. 
+ per_channel (bool): Quantization granularity based on layer or on channel. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. @@ -919,6 +926,7 @@ class HSwishQuant(_QuantActivation): num_bits=8, quant_delay=0, ema_decay=0.999, + per_channel=False, symmetric=False, narrow_range=False): super(HSwishQuant, self).__init__() @@ -927,6 +935,7 @@ class HSwishQuant(_QuantActivation): num_bits=num_bits, quant_delay=quant_delay, ema=True, + per_channel=per_channel, ema_decay=ema_decay, symmetric=symmetric, narrow_range=narrow_range) @@ -935,6 +944,7 @@ class HSwishQuant(_QuantActivation): num_bits=num_bits, quant_delay=quant_delay, ema=True, + per_channel=per_channel, ema_decay=ema_decay, symmetric=symmetric, narrow_range=narrow_range) @@ -959,6 +969,7 @@ class HSigmoidQuant(_QuantActivation): num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. + per_channel (bool): Quantization granularity based on layer or on channel. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. @@ -978,6 +989,7 @@ class HSigmoidQuant(_QuantActivation): num_bits=8, quant_delay=0, ema_decay=0.999, + per_channel=False, symmetric=False, narrow_range=False): super(HSigmoidQuant, self).__init__() @@ -986,6 +998,7 @@ class HSigmoidQuant(_QuantActivation): num_bits=num_bits, quant_delay=quant_delay, ema=True, + per_channel=per_channel, symmetric=symmetric, narrow_range=narrow_range) self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, @@ -993,6 +1006,7 @@ class HSigmoidQuant(_QuantActivation): num_bits=num_bits, quant_delay=quant_delay, ema=True, + per_channel=per_channel, ema_decay=ema_decay, symmetric=symmetric, narrow_range=narrow_range) @@ -1017,6 +1031,7 @@ class TensorAddQuant(Cell): num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. + per_channel (bool): Quantization granularity based on layer or on channel. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. @@ -1037,6 +1052,7 @@ class TensorAddQuant(Cell): num_bits=8, quant_delay=0, ema_decay=0.999, + per_channel=False, symmetric=False, narrow_range=False): super(TensorAddQuant, self).__init__() @@ -1045,6 +1061,7 @@ class TensorAddQuant(Cell): num_bits=num_bits, quant_delay=quant_delay, ema=True, + per_channel=per_channel, ema_decay=ema_decay, symmetric=symmetric, narrow_range=narrow_range) @@ -1066,6 +1083,7 @@ class MulQuant(Cell): num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. + per_channel (bool): Quantization granularity based on layer or on channel. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. 
Default: False. @@ -1081,6 +1099,7 @@ class MulQuant(Cell): num_bits=8, quant_delay=0, ema_decay=0.999, + per_channel=False, symmetric=False, narrow_range=False): super(MulQuant, self).__init__() @@ -1089,6 +1108,7 @@ class MulQuant(Cell): num_bits=num_bits, quant_delay=quant_delay, ema=True, + per_channel=per_channel, ema_decay=ema_decay, symmetric=symmetric, narrow_range=narrow_range) diff --git a/mindspore/train/callback/_loss_monitor.py b/mindspore/train/callback/_loss_monitor.py index 15a095c5cb..22b1342873 100644 --- a/mindspore/train/callback/_loss_monitor.py +++ b/mindspore/train/callback/_loss_monitor.py @@ -14,6 +14,7 @@ # ============================================================================ """LossMonitor Callback class.""" +import time import numpy as np from mindspore.common.tensor import Tensor @@ -31,32 +32,62 @@ class LossMonitor(Callback): Args: per_print_times (int): Print loss every times. Default: 1. + lr_init (numpy array): train learning rate. Default: None. Raises: ValueError: If print_step is not int or less than zero. + + Examples: + >>> LossMonitor(100, lr_init=Tensor([0.05]*100).asnumpy()) """ - def __init__(self, per_print_times=1): + def __init__(self, per_print_times=1, lr_init=None): super(LossMonitor, self).__init__() if not isinstance(per_print_times, int) or per_print_times < 0: raise ValueError("print_step must be int and >= 0.") self._per_print_times = per_print_times + self.lr_init = lr_init + + def epoch_begin(self, run_context): + self.losses = [] + self.epoch_time = time.time() + + def epoch_end(self, run_context): + cb_params = run_context.original_args() + epoch_mseconds = (time.time() - self.epoch_time) * 1000 + per_step_mseconds = epoch_mseconds / cb_params.batch_num + print("Epoch time: {:5.3f}, per step time: {:5.3f}, " + "avg loss: {:5.3f}".format(epoch_mseconds, + per_step_mseconds, + np.mean(self.losses))) + print("*" * 60) + + def step_begin(self, run_context): + self.step_time = time.time() def step_end(self, run_context): cb_params = run_context.original_args() - loss = cb_params.net_outputs + step_mseconds = (time.time() - self.step_time) * 1000 + step_loss = cb_params.net_outputs - if isinstance(loss, (tuple, list)): - if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray): - loss = loss[0] + if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor): + step_loss = step_loss[0] + if isinstance(step_loss, Tensor): + step_loss = np.mean(step_loss.asnumpy()) - if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray): - loss = np.mean(loss.asnumpy()) + self.losses.append(step_loss) + cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num - cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 + if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)): + raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. " + "Invalid loss, terminating training.".format( + cb_params.cur_epoch_num - 1, cb_params.epoch_num, + cur_step_in_epoch, cb_params.batch_num)) - if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)): - raise ValueError("epoch: {} step: {}. 
Invalid loss, terminating training.".format( - cb_params.cur_epoch_num, cur_step_in_epoch)) if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: - print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True) + print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], " + "loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format( + cb_params.cur_epoch_num - 1, cb_params.epoch_num, + cur_step_in_epoch, cb_params.batch_num, + step_loss, np.mean(self.losses), + step_mseconds), flush=True) diff --git a/mindspore/train/callback/_time_monitor.py b/mindspore/train/callback/_time_monitor.py index c810306d24..9fbdf83aa8 100644 --- a/mindspore/train/callback/_time_monitor.py +++ b/mindspore/train/callback/_time_monitor.py @@ -32,4 +32,4 @@ class TimeMonitor(Callback): def epoch_end(self, run_context): epoch_mseconds = (time.time() - self.epoch_time) * 1000 per_step_mseconds = epoch_mseconds / self.data_size - print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True) + print("Epoch time: {:5.3f}, per step time: {:5.3f}".format(epoch_mseconds, per_step_mseconds), flush=True) diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py index 01724f285c..a8f381425c 100644 --- a/mindspore/train/quant/quant.py +++ b/mindspore/train/quant/quant.py @@ -32,6 +32,7 @@ from ...ops.operations import _inner_ops as inner from ...train import serialization from . import quant_utils + _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant, nn.ReLU6: quant.ReLU6Quant, nn.HSigmoid: quant.HSigmoidQuant, @@ -61,14 +62,17 @@ class _AddFakeQuantAfterSubCell(nn.Cell): Add FakeQuant after of the sub Cell. """ - def __init__(self, subcell, quant_delay=0, num_bits=8): + def __init__(self, subcell, **kwargs): super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False) self.subcell = subcell self.fake_quant_act = quant.FakeQuantWithMinMax(min_init=-6, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, - ema=True) + ema=True, + num_bits=kwargs["num_bits"], + quant_delay=kwargs["quant_delay"], + per_channel=kwargs["per_channel"], + symmetric=kwargs["symmetric"], + narrow_range=kwargs["narrow_range"]) def construct(self, *data): output = self.subcell(*data) @@ -82,30 +86,20 @@ class ConvertToQuantNetwork: """ __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] - def __init__(self, - network, - quant_delay=0, - bn_fold=False, - freeze_bn=0, - weight_bits=8, - act_bits=8, - per_channel=False, - symmetric=False, - narrow_range=False): - self.network = validator.check_isinstance( - 'network', network, (nn.Cell,)) - self.quant_delay = validator.check_integer( - "quant delay", quant_delay, 0, Rel.GE) - self.freeze_bn = validator.check_integer( - "freeze bn", freeze_bn, 0, Rel.GE) - self.weight_bits = validator.check_integer( - "weights bit", weight_bits, 0, Rel.GE) - self.act_bits = validator.check_integer( - "activations bit", act_bits, 0, Rel.GE) - self.bn_fold = validator.check_bool("bn fold", bn_fold) - self.per_channel = validator.check_bool("per channel", per_channel) - self.symmetric = validator.check_bool("symmetric", symmetric) - self.narrow_range = validator.check_bool("narrow range", narrow_range) + def __init__(self, **kwargs): + self.network = validator.check_isinstance('network', kwargs["network"], (nn.Cell,)) + self.weight_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE) + self.act_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, 
Rel.GE) + self.bn_fold = validator.check_bool("bn fold", kwargs["bn_fold"]) + self.freeze_bn = validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE) + self.weight_bits = validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE) + self.act_bits = validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE) + self.weight_channel = validator.check_bool("per channel", kwargs["per_channel"][0]) + self.act_channel = validator.check_bool("per channel", kwargs["per_channel"][-1]) + self.weight_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][0]) + self.act_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][-1]) + self.weight_range = validator.check_bool("narrow range", kwargs["narrow_range"][0]) + self.act_range = validator.check_bool("narrow range", kwargs["narrow_range"][-1]) self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv, quant.DenseBnAct: self._convert_dense} @@ -153,7 +147,12 @@ class ConvertToQuantNetwork: add_list.append((name, attr)) for name, prim_op in add_list: prefix = name - add_quant = _AddFakeQuantAfterSubCell(prim_op) # quant.TensorAddQuant() + add_quant = _AddFakeQuantAfterSubCell(prim_op, + num_bits=self.act_bits, + quant_delay=self.act_qdelay, + per_channel=self.act_channel, + symmetric=self.act_symmetric, + narrow_range=self.act_range) prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)]) add_quant.update_parameters_name(prefix + '.') del network.__dict__[name] @@ -177,13 +176,13 @@ class ConvertToQuantNetwork: group=conv_inner.group, eps=bn_inner.eps, momentum=bn_inner.momentum, - quant_delay=self.quant_delay, + quant_delay=self.weight_qdelay, freeze_bn=self.freeze_bn, - per_channel=self.per_channel, + per_channel=self.weight_channel, num_bits=self.weight_bits, fake=True, - symmetric=self.symmetric, - narrow_range=self.narrow_range) + symmetric=self.weight_symmetric, + narrow_range=self.weight_range) del subcell.batchnorm subcell.batchnorm = None subcell.has_bn = False @@ -197,18 +196,22 @@ class ConvertToQuantNetwork: dilation=conv_inner.dilation, group=conv_inner.group, has_bias=conv_inner.has_bias, - quant_delay=self.quant_delay, - per_channel=self.per_channel, + quant_delay=self.weight_qdelay, + per_channel=self.weight_channel, num_bits=self.weight_bits, - symmetric=self.symmetric, - narrow_range=self.narrow_range) + symmetric=self.weight_symmetric, + narrow_range=self.weight_range) subcell.conv = conv_inner if subcell.has_act and subcell.activation is not None: subcell.activation = self._convert_activation(subcell.activation) else: subcell.has_act = True - subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits, - quant_delay=self.quant_delay) + subcell.activation = _AddFakeQuantAfterSubCell(F.identity, + num_bits=self.act_bits, + quant_delay=self.act_qdelay, + per_channel=self.act_channel, + symmetric=self.act_symmetric, + narrow_range=self.act_range) return subcell def _convert_dense(self, subcell): @@ -219,16 +222,22 @@ class ConvertToQuantNetwork: dense_inner = quant.DenseQuant(dense_inner.in_channels, dense_inner.out_channels, has_bias=dense_inner.has_bias, - quant_delay=self.quant_delay, - per_channel=self.per_channel, - num_bits=self.weight_bits) + num_bits=self.weight_bits, + quant_delay=self.weight_qdelay, + per_channel=self.weight_channel, + symmetric=self.weight_symmetric, + narrow_range=self.weight_range) subcell.dense = dense_inner if subcell.has_act and subcell.activation is not None: subcell.activation =
self._convert_activation(subcell.activation) else: subcell.has_act = True - subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits, - quant_delay=self.quant_delay) + subcell.activation = _AddFakeQuantAfterSubCell(F.identity, + num_bits=self.act_bits, + quant_delay=self.act_qdelay, + per_channel=self.act_channel, + symmetric=self.act_symmetric, + narrow_range=self.act_range) return subcell def _convert_activation(self, activation): @@ -236,7 +245,11 @@ class ConvertToQuantNetwork: if act_class not in _ACTIVATION_MAP: raise ValueError( "Unsupported activation in auto Quant: ", act_class) - return _ACTIVATION_MAP[act_class](num_bits=self.act_bits, quant_delay=self.quant_delay) + return _ACTIVATION_MAP[act_class](num_bits=self.act_bits, + quant_delay=self.act_qdelay, + per_channel=self.act_channel, + symmetric=self.weight_symmetric, + narrow_range=self.weight_range) class ExportQuantNetworkDeploy: @@ -381,32 +394,57 @@ def export_geir(network, *inputs, file_name): def convert_quant_network(network, - quant_delay=0, bn_fold=False, freeze_bn=0, - weight_bits=8, - act_bits=8, - per_channel=False, - symmetric=False, - narrow_range=False + quant_delay=(0, 0), + num_bits=(8, 8), + per_channel=(False, False), + symmetric=(False, False), + narrow_range=(False, False) ): r""" Create aware quantizaiton training network. Args: network (Cell): Obtain a pipeline through network for saving graph summary. - quant_delay (int): Number of steps after which weights and activations are quantized during eval. Default: 0. + quant_delay (list of int): Number of steps after which weights and activations are quantized during + eval. The first element represents weights and the second element represents data flow. Default: [0, 0] bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False. - freeze_bn (int): Number of steps after which BN parameters used total mean and variance. Default: 0. - weight_bits (int): Number of bits to use for quantizing weights. Default: 8. - act_bits (int): Number of bits to use for quantizing activations. Default: 8. - per_channel (bool): Quantization granularity based on layer or on channel. Default: False. - symmetric (bool): Quantization algorithm use symmetric or not. Default: False. - narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + freeze_bn (int): Number of steps after which BatchNorm OP parameters use the total mean and variance. Default: 0. + num_bits (list of int): Number of bits to use for quantizing weights and activations. The first + element represents weights and the second element represents data flow. Default: [8, 8] + per_channel (list of bool): Quantization granularity based on layer or on channel. If `True`, + quantization is applied per channel, otherwise per layer. The first element represents weights + and the second element represents data flow. Default: [False, False] + symmetric (list of bool): Quantization algorithm use symmetric or not. If `True`, the algorithm is + symmetric, otherwise asymmetric. The first element represents weights and the second + element represents data flow. Default: [False, False] + narrow_range (list of bool): Quantization algorithm use narrow range or not. If `True`, narrow range + is used, otherwise it is not. The first element represents weights and + the second element represents data flow. Default: [False, False] Returns: - Cell, Network which has change to aware quantization training network.
+ Cell, Network which has been changed to aware quantization training network cell. """ - net = ConvertToQuantNetwork( - network, quant_delay, bn_fold, freeze_bn, weight_bits, act_bits, per_channel, symmetric, narrow_range) + def convert2list(name, value): + if not isinstance(value, list) and not isinstance(value, tuple): + value = [value] + elif len(value) > 2: + raise ValueError("input `{}` len should be no more than 2".format(name)) + return value + + quant_delay = convert2list("quant delay", quant_delay) + num_bits = convert2list("num bits", num_bits) + per_channel = convert2list("per channel", per_channel) + symmetric = convert2list("symmetric", symmetric) + narrow_range = convert2list("narrow range", narrow_range) + + net = ConvertToQuantNetwork(network=network, + quant_delay=quant_delay, + bn_fold=bn_fold, + freeze_bn=freeze_bn, + num_bits=num_bits, + per_channel=per_channel, + symmetric=symmetric, + narrow_range=narrow_range) return net.run() diff --git a/model_zoo/alexnet/eval.py b/model_zoo/alexnet/eval.py index c59284e05f..82544de43d 100644 --- a/model_zoo/alexnet/eval.py +++ b/model_zoo/alexnet/eval.py @@ -54,4 +54,4 @@ if __name__ == "__main__": cfg.batch_size, status="test") acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) - print("============== Accuracy:{} ==============".format(acc)) + print("============== {} ==============".format(acc)) diff --git a/model_zoo/lenet/eval.py b/model_zoo/lenet/eval.py index ee1f794695..445689623f 100644 --- a/model_zoo/lenet/eval.py +++ b/model_zoo/lenet/eval.py @@ -61,4 +61,4 @@ if __name__ == "__main__": cfg.batch_size, 1) acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) - print("============== Accuracy:{} ==============".format(acc)) + print("============== {} ==============".format(acc)) diff --git a/model_zoo/lstm/eval.py b/model_zoo/lstm/eval.py index 04e60d3a07..a9b81199c1 100644 --- a/model_zoo/lstm/eval.py +++ b/model_zoo/lstm/eval.py @@ -78,4 +78,4 @@ if __name__ == '__main__': acc = model.eval(ds_eval, dataset_sink_mode=False) else: acc = model.eval(ds_eval) - print("============== Accuracy:{} ==============".format(acc)) + print("============== {} ==============".format(acc)) diff --git a/tests/st/networks/test_gpu_lenet.py b/tests/st/networks/test_gpu_lenet.py index 038af92223..45774ce87e 100644 --- a/tests/st/networks/test_gpu_lenet.py +++ b/tests/st/networks/test_gpu_lenet.py @@ -203,4 +203,4 @@ def test_train_and_eval_lenet(): print("============== Starting Testing ==============") ds_eval = create_dataset(os.path.join('/home/workspace/mindspore_dataset/mnist', "test"), 32, 1) acc = model.eval(ds_eval, dataset_sink_mode=True) - print("============== Accuracy:{} ==============".format(acc)) + print("============== {} ==============".format(acc)) diff --git a/tests/ut/python/train/quant/test_quant.py b/tests/ut/python/train/quant/test_quant.py index 6098354cb0..c9398be456 100644 --- a/tests/ut/python/train/quant/test_quant.py +++ b/tests/ut/python/train/quant/test_quant.py @@ -67,7 +67,7 @@ def test_qat_lenet(): img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32)) net = LeNet5() net = qat.convert_quant_network( - net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8) + net, quant_delay=0, bn_fold=False, freeze_bn=10000, num_bits=8) # should load the checkpoint.
mock here for param in net.get_parameters(): param.init_data() @@ -79,7 +79,7 @@ def test_qat_mobile(): net = MobileNetV2() img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) net = qat.convert_quant_network( - net, quant_delay=0, bn_fold=True, freeze_bn=10000, weight_bits=8, act_bits=8) + net, quant_delay=0, bn_fold=True, freeze_bn=10000, num_bits=8) # should load the checkpoint. mock here for param in net.get_parameters(): param.init_data() diff --git a/tests/ut/python/utils/test_callback.py b/tests/ut/python/utils/test_callback.py index c4f6e0aa5b..e4ecfe696a 100644 --- a/tests/ut/python/utils/test_callback.py +++ b/tests/ut/python/utils/test_callback.py @@ -117,6 +117,7 @@ def test_loss_monitor_sink_mode(): """Test loss monitor sink mode.""" cb_params = _InternalCallbackParam() cb_params.cur_epoch_num = 4 + cb_params.epoch_num = 4 cb_params.cur_step_num = 2 cb_params.batch_num = 2 cb_params.net_outputs = Tensor(2.0) @@ -138,6 +139,7 @@ def test_loss_monitor_normal_mode(): run_context = RunContext(cb_params) loss_cb = LossMonitor(1) cb_params.cur_epoch_num = 4 + cb_params.epoch_num = 4 cb_params.cur_step_num = 1 cb_params.batch_num = 1 cb_params.net_outputs = Tensor(2.0)
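Usage note (not part of the patch): a minimal sketch of how the reworked convert_quant_network API above can be called, mirroring tests/ut/python/train/quant/test_quant.py. The LeNet5 fusion network, the `qat` import alias, and the two-element (weights, activations) tuples are assumptions taken from that test and from the new docstring, not a definitive recipe.

import numpy as np
from mindspore import Tensor
from mindspore.train.quant import quant as qat

# LeNet5 stands for any network built from Conv2dBnAct/DenseBnAct fusion cells,
# as in test_qat_lenet(); it is assumed to be defined elsewhere.
net = LeNet5()
# Each pair is (weight setting, activation setting); scalars are also accepted
# and broadcast to both positions by convert2list.
net = qat.convert_quant_network(net,
                                bn_fold=False,
                                freeze_bn=10000,
                                quant_delay=(0, 0),
                                num_bits=(8, 8),
                                per_channel=(False, False),
                                symmetric=(False, False),
                                narrow_range=(False, False))
# The unit test initializes parameters in place of loading a checkpoint.
for param in net.get_parameters():
    param.init_data()
# Forward-pass sanity check with the input shape used in the unit test.
img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
out = net(img)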