Browse Source

!5447 Support manual convert to quantative network of resnet

Merge pull request !5447 from chenfei_mindspore/r0.6
pull/5447/MERGE
mindspore-ci-bot Gitee 5 years ago
parent
commit
d1b1a626c2
6 changed files with 365 additions and 7 deletions
  1. +31
    -0
      model_zoo/official/cv/mobilenetv2_quant/src/utils.py
  2. +3
    -3
      model_zoo/official/cv/mobilenetv2_quant/train.py
  3. +2
    -1
      model_zoo/official/cv/resnet50_quant/eval.py
  4. +325
    -0
      model_zoo/official/cv/resnet50_quant/models/resnet_quant_manual.py
  5. +2
    -2
      model_zoo/official/cv/resnet50_quant/src/config.py
  6. +2
    -1
      model_zoo/official/cv/resnet50_quant/train.py

+ 31
- 0
model_zoo/official/cv/mobilenetv2_quant/src/utils.py View File

@@ -111,3 +111,34 @@ class CrossEntropyWithLabelSmooth(_Loss):
out_loss = self.ce(logit, one_hot_label) out_loss = self.ce(logit, one_hot_label)
out_loss = self.mean(out_loss, 0) out_loss = self.mean(out_loss, 0)
return out_loss return out_loss

def _load_param_into_net(model, params_dict):
"""
load fp32 model parameters to quantization model.

Args:
model: quantization model
params_dict: f32 param

Returns:
None
"""
iterable_dict = {
'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]),
'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]),
'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]),
'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]),
'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]),
'moving_variance': iter(
[item for item in params_dict.items() if item[0].endswith('moving_variance')]),
'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]),
'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')])
}
for name, param in model.parameters_and_names():
key_name = name.split(".")[-1]
if key_name not in iterable_dict.keys():
continue
value_param = next(iterable_dict[key_name], None)
if value_param is not None:
param.set_parameter_data(value_param[1].data)
print(f'init model param {name} with checkpoint param {value_param[0]}')

+ 3
- 3
model_zoo/official/cv/mobilenetv2_quant/train.py View File

@@ -24,14 +24,14 @@ from mindspore import Tensor
from mindspore import nn from mindspore import nn
from mindspore.train.model import Model, ParallelMode from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.serialization import load_checkpoint
from mindspore.communication.management import init from mindspore.communication.management import init
from mindspore.train.quant import quant from mindspore.train.quant import quant
import mindspore.dataset.engine as de import mindspore.dataset.engine as de


from src.dataset import create_dataset from src.dataset import create_dataset
from src.lr_generator import get_lr from src.lr_generator import get_lr
from src.utils import Monitor, CrossEntropyWithLabelSmooth
from src.utils import Monitor, CrossEntropyWithLabelSmooth, _load_param_into_net
from src.config import config_ascend, config_ascend_quant from src.config import config_ascend, config_ascend_quant
from src.mobilenetV2 import mobilenetV2 from src.mobilenetV2 import mobilenetV2


@@ -92,7 +92,7 @@ if __name__ == '__main__':
# load pre trained ckpt # load pre trained ckpt
if args_opt.pre_trained: if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained) param_dict = load_checkpoint(args_opt.pre_trained)
load_param_into_net(network, param_dict)
_load_param_into_net(network, param_dict)


# convert fusion network to quantization aware network # convert fusion network to quantization aware network
if config.quantization_aware: if config.quantization_aware:


+ 2
- 1
model_zoo/official/cv/resnet50_quant/eval.py View File

@@ -21,7 +21,8 @@ from src.config import quant_set, config_quant, config_noquant
from src.dataset import create_dataset from src.dataset import create_dataset
from src.crossentropy import CrossEntropy from src.crossentropy import CrossEntropy
from src.utils import _load_param_into_net from src.utils import _load_param_into_net
from models.resnet_quant import resnet50_quant
#from models.resnet_quant import resnet50_quant #auto construct quantative network of resnet50
from models.resnet_quant_manual import resnet50_quant #manually construct quantative network of resnet50


from mindspore import context from mindspore import context
from mindspore.train.model import Model from mindspore.train.model import Model


+ 325
- 0
model_zoo/official/cv/resnet50_quant/models/resnet_quant_manual.py View File

@@ -0,0 +1,325 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.nn import FakeQuantWithMinMax, Conv2dBnFoldQuant as Conv2dBatchNormQuant

_ema_decay = 0.999
_symmetric = True
_fake = True
_per_channel = True


def _weight_variable(shape, factor=0.01):
init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value)


def _conv3x3(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 3, 3)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _conv1x1(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 1, 1)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _conv7x7(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 7, 7)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _bn(channel):
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _bn_last(channel):
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _fc(in_channel, out_channel):
weight_shape = (out_channel, in_channel)
weight = _weight_variable(weight_shape)
return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)


class ConvBNReLU(nn.Cell):
"""
Convolution/Depthwise fused with Batchnorm and ReLU block definition.

Args:
in_planes (int): Input channel.
out_planes (int): Output channel.
kernel_size (int): Input kernel size.
stride (int): Stride size for the first convolutional layer. Default: 1.
groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1.

Returns:
Tensor, output tensor.

Examples:
>>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
"""

def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
group=groups, fake=_fake, per_channel=_per_channel, symmetric=_symmetric)
layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()]
self.features = nn.SequentialCell(layers)

def construct(self, x):
output = self.features(x)
return output


class ResidualBlock(nn.Cell):
"""
ResNet V1 residual block definition.

Args:
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.

Returns:
Tensor, output tensor.

Examples:
>>> ResidualBlock(3, 256, stride=2)
"""
expansion = 4

def __init__(self,
in_channel,
out_channel,
stride=1):
super(ResidualBlock, self).__init__()

channel = out_channel // self.expansion
self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, per_channel=_per_channel,
symmetric=_symmetric,
kernel_size=1, stride=1, pad_mode='same', padding=0),
FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=False)
]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
per_channel=_per_channel,
symmetric=_symmetric,
kernel_size=1, stride=1,
pad_mode='same', padding=0)

self.down_sample = False

if stride != 1 or in_channel != out_channel:
self.down_sample = True
self.down_sample_layer = None

if self.down_sample:
self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel,
per_channel=_per_channel,
symmetric=_symmetric,
kernel_size=1, stride=stride,
pad_mode='same', padding=0),
FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay,
symmetric=False)
]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel,
fake=_fake,
per_channel=_per_channel,
symmetric=_symmetric,
kernel_size=1,
stride=stride,
pad_mode='same',
padding=0)
self.add = nn.TensorAddQuant()
self.relu = P.ReLU()

def construct(self, x):
identity = x
out = self.conv1(x)
out = self.conv2(out)
out = self.conv3(out)

if self.down_sample:
identity = self.down_sample_layer(identity)

out = self.add(out, identity)
out = self.relu(out)

return out


class ResNet(nn.Cell):
"""
ResNet architecture.

Args:
block (Cell): Block for network.
layer_nums (list): Numbers of block in different layers.
in_channels (list): Input channel in each layer.
out_channels (list): Output channel in each layer.
strides (list): Stride size in each layer.
num_classes (int): The number of classes that the training images are belonging to.
Returns:
Tensor, output tensor.

Examples:
>>> ResNet(ResidualBlock,
>>> [3, 4, 6, 3],
>>> [64, 256, 512, 1024],
>>> [256, 512, 1024, 2048],
>>> [1, 2, 2, 2],
>>> 10)
"""

def __init__(self,
block,
layer_nums,
in_channels,
out_channels,
strides,
num_classes):
super(ResNet, self).__init__()

if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")

self.conv1 = ConvBNReLU(3, 64, kernel_size=7, stride=2)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=strides[0])
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
out_channel=out_channels[1],
stride=strides[1])
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=strides[2])
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=strides[3])

self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, per_channel=_per_channel,
symmetric=_symmetric)
self.output_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay)

def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
"""
Make stage network of ResNet.

Args:
block (Cell): Resnet block.
layer_num (int): Layer number.
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer.

Returns:
SequentialCell, the output layer.

Examples:
>>> _make_layer(ResidualBlock, 3, 128, 256, 2)
"""
layers = []

resnet_block = block(in_channel, out_channel, stride=stride)
layers.append(resnet_block)

for _ in range(1, layer_num):
resnet_block = block(out_channel, out_channel, stride=1)
layers.append(resnet_block)

return nn.SequentialCell(layers)

def construct(self, x):
x = self.conv1(x)
c1 = self.maxpool(x)

c2 = self.layer1(c1)
c3 = self.layer2(c2)
c4 = self.layer3(c3)
c5 = self.layer4(c4)

out = self.mean(c5, (2, 3))
out = self.flatten(out)
out = self.end_point(out)
out = self.output_fake(out)
return out


def resnet50_quant(class_num=10):
"""
Get ResNet50 neural network.

Args:
class_num (int): Class number.

Returns:
Cell, cell instance of ResNet50 neural network.

Examples:
>>> net = resnet50_quant(10)
"""
return ResNet(ResidualBlock,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num)


def resnet101_quant(class_num=1001):
"""
Get ResNet101 neural network.

Args:
class_num (int): Class number.

Returns:
Cell, cell instance of ResNet101 neural network.

Examples:
>>> net = resnet101(1001)
"""
return ResNet(ResidualBlock,
[3, 4, 23, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num)

+ 2
- 2
model_zoo/official/cv/resnet50_quant/src/config.py View File

@@ -31,7 +31,7 @@ config_noquant = ed({
"buffer_size": 1000, "buffer_size": 1000,
"image_height": 224, "image_height": 224,
"image_width": 224, "image_width": 224,
"data_load_mode": "mindrecord",
"data_load_mode": "mindata",
"save_checkpoint": True, "save_checkpoint": True,
"save_checkpoint_epochs": 1, "save_checkpoint_epochs": 1,
"keep_checkpoint_max": 50, "keep_checkpoint_max": 50,
@@ -54,7 +54,7 @@ config_quant = ed({
"buffer_size": 1000, "buffer_size": 1000,
"image_height": 224, "image_height": 224,
"image_width": 224, "image_width": 224,
"data_load_mode": "mindrecord",
"data_load_mode": "mindata",
"save_checkpoint": True, "save_checkpoint": True,
"save_checkpoint_epochs": 1, "save_checkpoint_epochs": 1,
"keep_checkpoint_max": 50, "keep_checkpoint_max": 50,


+ 2
- 1
model_zoo/official/cv/resnet50_quant/train.py View File

@@ -30,7 +30,8 @@ from mindspore.communication.management import init
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.common.initializer as weight_init import mindspore.common.initializer as weight_init


from models.resnet_quant import resnet50_quant
#from models.resnet_quant import resnet50_quant #auto construct quantative network of resnet50
from models.resnet_quant_manual import resnet50_quant #manually construct quantative network of resnet50
from src.dataset import create_dataset from src.dataset import create_dataset
from src.lr_generator import get_lr from src.lr_generator import get_lr
from src.config import quant_set, config_quant, config_noquant from src.config import quant_set, config_quant, config_noquant


Loading…
Cancel
Save