From: @jiangzg001  Reviewed-by: @oacjiewen, @wuxuejian  Signed-off-by: @wuxuejian  (r1.2)
--- a/eval.py
+++ b/eval.py
@@ -16,14 +16,12 @@
 import os
 import argparse
-import numpy as np
-import cv2
-from mindspore import Tensor
-import mindspore.common.dtype as mstype
-import mindspore.nn as nn
 from mindspore import context
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.nets import net_factory
+from src.utils.eval_utils import BuildEvalNetwork, net_eval
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
                     device_id=int(os.getenv('DEVICE_ID')))
@@ -45,169 +43,24 @@ def parse_args():
     # model
     parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model')
-    parser.add_argument('--freeze_bn', action='store_true', default=False, help='freeze bn')
     parser.add_argument('--ckpt_path', type=str, default='', help='model to evaluate')
 
-    args_space, _ = parser.parse_known_args()
-    return args_space
+    args, _ = parser.parse_known_args()
+    return args
-
-
-def cal_hist(a, b, n):
-    k = (a >= 0) & (a < n)
-    return np.bincount(n * a[k].astype(np.int32) + b[k], minlength=n ** 2).reshape(n, n)
-
-
-def resize_long(img, long_size=513):
-    h, w, _ = img.shape
-    if h > w:
-        new_h = long_size
-        new_w = int(1.0 * long_size * w / h)
-    else:
-        new_w = long_size
-        new_h = int(1.0 * long_size * h / w)
-    imo = cv2.resize(img, (new_w, new_h))
-    return imo
-
-
-class BuildEvalNetwork(nn.Cell):
-    def __init__(self, network):
-        super(BuildEvalNetwork, self).__init__()
-        self.network = network
-        self.softmax = nn.Softmax(axis=1)
-
-    def construct(self, input_data):
-        output = self.network(input_data)
-        output = self.softmax(output)
-        return output
-
-
-def pre_process(args, img_, crop_size=513):
-    # resize
-    img_ = resize_long(img_, crop_size)
-    resize_h, resize_w, _ = img_.shape
-
-    # mean, std
-    image_mean = np.array(args.image_mean)
-    image_std = np.array(args.image_std)
-    img_ = (img_ - image_mean) / image_std
-
-    # pad to crop_size
-    pad_h = crop_size - img_.shape[0]
-    pad_w = crop_size - img_.shape[1]
-    if pad_h > 0 or pad_w > 0:
-        img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
-
-    # hwc to chw
-    img_ = img_.transpose((2, 0, 1))
-    return img_, resize_h, resize_w
-
-
-def eval_batch(args, eval_net, img_lst, crop_size=513, flip=True):
-    result_lst = []
-    batch_size = len(img_lst)
-    batch_img = np.zeros((args.batch_size, 3, crop_size, crop_size), dtype=np.float32)
-    resize_hw = []
-    for l in range(batch_size):
-        img_ = img_lst[l]
-        img_, resize_h, resize_w = pre_process(args, img_, crop_size)
-        batch_img[l] = img_
-        resize_hw.append([resize_h, resize_w])
-
-    batch_img = np.ascontiguousarray(batch_img)
-    net_out = eval_net(Tensor(batch_img, mstype.float32))
-    net_out = net_out.asnumpy()
-    if flip:
-        batch_img = batch_img[:, :, :, ::-1]
-        net_out_flip = eval_net(Tensor(batch_img, mstype.float32))
-        net_out += net_out_flip.asnumpy()[:, :, :, ::-1]
-
-    for bs in range(batch_size):
-        probs_ = net_out[bs][:, :resize_hw[bs][0], :resize_hw[bs][1]].transpose((1, 2, 0))
-        ori_h, ori_w = img_lst[bs].shape[0], img_lst[bs].shape[1]
-        probs_ = cv2.resize(probs_, (ori_w, ori_h))
-        result_lst.append(probs_)
-
-    return result_lst
-
-
-def eval_batch_scales(args, eval_net, img_lst, scales,
-                      base_crop_size=513, flip=True):
-    sizes_ = [int((base_crop_size - 1) * sc) + 1 for sc in scales]
-    probs_lst = eval_batch(args, eval_net, img_lst, crop_size=sizes_[0], flip=flip)
-    print(sizes_)
-    for crop_size_ in sizes_[1:]:
-        probs_lst_tmp = eval_batch(args, eval_net, img_lst, crop_size=crop_size_, flip=flip)
-        for pl, _ in enumerate(probs_lst):
-            probs_lst[pl] += probs_lst_tmp[pl]
-
-    result_msk = []
-    for i in probs_lst:
-        result_msk.append(i.argmax(axis=2))
-    return result_msk
-def net_eval():
+if __name__ == '__main__':
     args = parse_args()
 
-    # data list
-    with open(args.data_lst) as f:
-        img_lst = f.readlines()
-
     # network
     if args.model == 'deeplab_v3_s16':
-        network = net_factory.nets_map[args.model]('eval', args.num_classes, 16, args.freeze_bn)
+        network = net_factory.nets_map[args.model](args.num_classes, 16)
     elif args.model == 'deeplab_v3_s8':
-        network = net_factory.nets_map[args.model]('eval', args.num_classes, 8, args.freeze_bn)
+        network = net_factory.nets_map[args.model](args.num_classes, 8)
     else:
         raise NotImplementedError('model [{:s}] not recognized'.format(args.model))
 
-    eval_net = BuildEvalNetwork(network)
+    eval_net = BuildEvalNetwork(network, args.input_format)
 
     # load model
     param_dict = load_checkpoint(args.ckpt_path)
     load_param_into_net(eval_net, param_dict)
     eval_net.set_train(False)
 
-    # evaluate
-    hist = np.zeros((args.num_classes, args.num_classes))
-    batch_img_lst = []
-    batch_msk_lst = []
-    bi = 0
-    image_num = 0
-    for i, line in enumerate(img_lst):
-        img_path, msk_path = line.strip().split(' ')
-        img_path = os.path.join(args.data_root, img_path)
-        msk_path = os.path.join(args.data_root, msk_path)
-        img_ = cv2.imread(img_path)
-        msk_ = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)
-        batch_img_lst.append(img_)
-        batch_msk_lst.append(msk_)
-        bi += 1
-        if bi == args.batch_size:
-            batch_res = eval_batch_scales(args, eval_net, batch_img_lst, scales=args.scales,
-                                          base_crop_size=args.crop_size, flip=args.flip)
-            for mi in range(args.batch_size):
-                hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
-            bi = 0
-            batch_img_lst = []
-            batch_msk_lst = []
-            print('processed {} images'.format(i + 1))
-        image_num = i
-
-    if bi > 0:
-        batch_res = eval_batch_scales(args, eval_net, batch_img_lst, scales=args.scales,
-                                      base_crop_size=args.crop_size, flip=args.flip)
-        for mi in range(bi):
-            hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
-        print('processed {} images'.format(image_num + 1))
-
-    print(hist)
-    iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
-    print('per-class IoU', iu)
-    print('mean IoU', np.nanmean(iu))
-
-
-if __name__ == '__main__':
-    net_eval()
+    net_eval(args, eval_net)
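For orientation, the relocated helpers can also be driven without the CLI. A minimal sketch (the checkpoint path, data paths, and the mean/std values are placeholder assumptions; the `Namespace` fields mirror the flags `net_eval` reads):

```python
# Sketch only: exercising the refactored eval utilities directly.
from argparse import Namespace
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.nets import net_factory
from src.utils.eval_utils import BuildEvalNetwork, net_eval

args = Namespace(data_root='/PATH/TO/DATA', data_lst='/PATH/TO/DATA_lst.txt',
                 num_classes=21, batch_size=16, crop_size=513,
                 scales=[1.0], flip=False,
                 image_mean=[103.53, 116.28, 123.675],   # assumed dataset statistics
                 image_std=[57.375, 57.120, 58.395])

network = net_factory.nets_map['deeplab_v3_s16'](args.num_classes, 16)
eval_net = BuildEvalNetwork(network, 'NCHW')
load_param_into_net(eval_net, load_checkpoint('/PATH/TO/CKPT'))
eval_net.set_train(False)
mean_iou = net_eval(args, eval_net)   # prints per-class IoU, returns mean IoU
```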
@@ -27,7 +27,10 @@ if [ -d ${train_path} ]; then
 fi
 mkdir -p ${train_path}
 mkdir ${train_path}/ckpt
+
+# If evaluation during training is enabled, `data_root` and `data_lst` must be set;
+# otherwise they can be passed as empty strings ("").
 for((i=0;i<=$RANK_SIZE-1;i++));
 do
     export RANK_ID=${i}
@@ -37,6 +40,9 @@ do
     cd ${train_path}/device${DEVICE_ID} || exit
     python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
                                        --data_file=/PATH/TO/MINDRECORD_NAME \
+                                       --data_root=/PATH/TO/DATA \
+                                       --val_data=/PATH/TO/DATA_lst.txt \
+                                       --scales=1.0 \
                                        --train_epochs=300 \
                                        --batch_size=32 \
                                        --crop_size=513 \
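As an illustration of the comment added above, the two supported configurations look roughly like this (a sketch; the bare `train.py` invocation and all paths are placeholders):

```bash
# Evaluation during training enabled: the validation paths must be real.
python train.py --data_file=/PATH/TO/MINDRECORD_NAME --run_eval=True \
                --data_root=/PATH/TO/DATA --val_data=/PATH/TO/DATA_lst.txt --scales=1.0

# Evaluation disabled: the two arguments may be empty strings.
python train.py --data_file=/PATH/TO/MINDRECORD_NAME --run_eval=False \
                --data_root="" --val_data=""
```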
@@ -27,7 +27,10 @@ if [ -d ${train_path} ]; then
 fi
 mkdir -p ${train_path}
 mkdir ${train_path}/ckpt
+
+# If evaluation during training is enabled, `data_root` and `data_lst` must be set;
+# otherwise they can be passed as empty strings ("").
 for((i=0;i<=$RANK_SIZE-1;i++));
 do
     export RANK_ID=${i}
@@ -37,6 +40,9 @@ do
     cd ${train_path}/device${DEVICE_ID} || exit
     python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
                                        --data_file=/PATH/TO/MINDRECORD_NAME \
+                                       --data_root=/PATH/TO/DATA \
+                                       --val_data=/PATH/TO/DATA_lst.txt \
+                                       --scales=1.0 \
                                        --train_epochs=800 \
                                        --batch_size=16 \
                                        --crop_size=513 \
@@ -51,5 +57,6 @@ do
                                        --ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
                                        --is_distributed \
                                        --save_steps=820 \
+                                       --scales=1.0 \
                                        --keep_checkpoint_max=200 >log 2>&1 &
 done
@@ -27,7 +27,10 @@ if [ -d ${train_path} ]; then
 fi
 mkdir -p ${train_path}
 mkdir ${train_path}/ckpt
+
+# If evaluation during training is enabled, `data_root` and `data_lst` must be set;
+# otherwise they can be passed as empty strings ("").
 for((i=0;i<=$RANK_SIZE-1;i++));
 do
     export RANK_ID=${i}
@@ -37,6 +40,9 @@ do
     cd ${train_path}/device${DEVICE_ID} || exit
     python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
                                        --data_file=/PATH/TO/MINDRECORD_NAME \
+                                       --data_root=/PATH/TO/DATA \
+                                       --val_data=/PATH/TO/DATA_lst.txt \
+                                       --scales=1.0 \
                                        --train_epochs=300 \
                                        --batch_size=16 \
                                        --crop_size=513 \
@@ -32,6 +32,5 @@ python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
                                   --num_classes=21 \
                                   --model=deeplab_v3_s16 \
                                   --scales=1.0 \
-                                  --freeze_bn \
                                   --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
@@ -32,6 +32,5 @@ python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
                                   --num_classes=21 \
                                   --model=deeplab_v3_s8 \
                                   --scales=1.0 \
-                                  --freeze_bn \
                                   --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
@@ -36,6 +36,5 @@ python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
                                   --scales=1.0 \
                                   --scales=1.25 \
                                   --scales=1.75 \
-                                  --freeze_bn \
                                   --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
@@ -37,6 +37,5 @@ python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
                                   --scales=1.25 \
                                   --scales=1.75 \
                                   --flip \
-                                  --freeze_bn \
                                   --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
@@ -26,9 +26,15 @@ mkdir -p ${train_path}
 mkdir ${train_path}/device${DEVICE_ID}
 mkdir ${train_path}/ckpt
 cd ${train_path}/device${DEVICE_ID} || exit
+
+# If evaluation during training is enabled, `data_root` and `data_lst` must be set;
+# otherwise they can be passed as empty strings ("").
 python ${train_code_path}/train.py --data_file=/PATH/TO/MINDRECORD_NAME \
                                    --train_dir=${train_path}/ckpt \
+                                   --data_root=/PATH/TO/DATA \
+                                   --val_data=/PATH/TO/DATA_lst.txt \
+                                   --scales=1.0 \
                                    --train_epochs=200 \
                                    --batch_size=32 \
                                    --crop_size=513 \
@@ -27,47 +27,41 @@ def conv3x3(in_planes, out_planes, stride=1, dilation=1, padding=1):
 
 class Resnet(nn.Cell):
-    def __init__(self, block, block_num, output_stride, use_batch_statistics=True):
+    def __init__(self, block, block_num, output_stride):
         super(Resnet, self).__init__()
         self.inplanes = 64
         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, pad_mode='pad', padding=3,
                                weight_init='xavier_uniform')
-        self.bn1 = nn.BatchNorm2d(self.inplanes, use_batch_statistics=use_batch_statistics)
+        self.bn1 = nn.BatchNorm2d(self.inplanes)
         self.relu = nn.ReLU()
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
-        self.layer1 = self._make_layer(block, 64, block_num[0], use_batch_statistics=use_batch_statistics)
-        self.layer2 = self._make_layer(block, 128, block_num[1], stride=2, use_batch_statistics=use_batch_statistics)
+        self.layer1 = self._make_layer(block, 64, block_num[0])
+        self.layer2 = self._make_layer(block, 128, block_num[1], stride=2)
 
         if output_stride == 16:
-            self.layer3 = self._make_layer(block, 256, block_num[2], stride=2,
-                                           use_batch_statistics=use_batch_statistics)
-            self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=2, grids=[1, 2, 4],
-                                           use_batch_statistics=use_batch_statistics)
+            self.layer3 = self._make_layer(block, 256, block_num[2], stride=2)
+            self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=2, grids=[1, 2, 4])
         elif output_stride == 8:
-            self.layer3 = self._make_layer(block, 256, block_num[2], stride=1, base_dilation=2,
-                                           use_batch_statistics=use_batch_statistics)
-            self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=4, grids=[1, 2, 4],
-                                           use_batch_statistics=use_batch_statistics)
+            self.layer3 = self._make_layer(block, 256, block_num[2], stride=1, base_dilation=2)
+            self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=4, grids=[1, 2, 4])
 
-    def _make_layer(self, block, planes, blocks, stride=1, base_dilation=1, grids=None, use_batch_statistics=True):
+    def _make_layer(self, block, planes, blocks, stride=1, base_dilation=1, grids=None):
         if stride != 1 or self.inplanes != planes * block.expansion:
             downsample = nn.SequentialCell([
                 conv1x1(self.inplanes, planes * block.expansion, stride),
-                nn.BatchNorm2d(planes * block.expansion, use_batch_statistics=use_batch_statistics)
+                nn.BatchNorm2d(planes * block.expansion)
             ])
 
         if grids is None:
             grids = [1] * blocks
 
         layers = [
-            block(self.inplanes, planes, stride, downsample, dilation=base_dilation * grids[0],
-                  use_batch_statistics=use_batch_statistics)
+            block(self.inplanes, planes, stride, downsample, dilation=base_dilation * grids[0])
         ]
 
         self.inplanes = planes * block.expansion
         for i in range(1, blocks):
             layers.append(
-                block(self.inplanes, planes, dilation=base_dilation * grids[i],
-                      use_batch_statistics=use_batch_statistics))
+                block(self.inplanes, planes, dilation=base_dilation * grids[i]))
 
         return nn.SequentialCell(layers)
@@ -88,16 +82,16 @@ class Resnet(nn.Cell):
 
 class Bottleneck(nn.Cell):
     expansion = 4
 
-    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, use_batch_statistics=True):
+    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
         super(Bottleneck, self).__init__()
         self.conv1 = conv1x1(inplanes, planes)
-        self.bn1 = nn.BatchNorm2d(planes, use_batch_statistics=use_batch_statistics)
+        self.bn1 = nn.BatchNorm2d(planes)
         self.conv2 = conv3x3(planes, planes, stride, dilation, dilation)
-        self.bn2 = nn.BatchNorm2d(planes, use_batch_statistics=use_batch_statistics)
+        self.bn2 = nn.BatchNorm2d(planes)
         self.conv3 = conv1x1(planes, planes * self.expansion)
-        self.bn3 = nn.BatchNorm2d(planes * self.expansion, use_batch_statistics=use_batch_statistics)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
         self.relu = nn.ReLU()
         self.downsample = downsample
@@ -127,19 +121,17 @@ class Bottleneck(nn.Cell):
 
 class ASPP(nn.Cell):
-    def __init__(self, atrous_rates, phase='train', in_channels=2048, num_classes=21,
-                 use_batch_statistics=True):
+    def __init__(self, atrous_rates, in_channels=2048, num_classes=21):
         super(ASPP, self).__init__()
-        self.phase = phase
         out_channels = 256
-        self.aspp1 = ASPPConv(in_channels, out_channels, atrous_rates[0], use_batch_statistics=use_batch_statistics)
-        self.aspp2 = ASPPConv(in_channels, out_channels, atrous_rates[1], use_batch_statistics=use_batch_statistics)
-        self.aspp3 = ASPPConv(in_channels, out_channels, atrous_rates[2], use_batch_statistics=use_batch_statistics)
-        self.aspp4 = ASPPConv(in_channels, out_channels, atrous_rates[3], use_batch_statistics=use_batch_statistics)
-        self.aspp_pooling = ASPPPooling(in_channels, out_channels, use_batch_statistics=use_batch_statistics)
+        self.aspp1 = ASPPConv(in_channels, out_channels, atrous_rates[0])
+        self.aspp2 = ASPPConv(in_channels, out_channels, atrous_rates[1])
+        self.aspp3 = ASPPConv(in_channels, out_channels, atrous_rates[2])
+        self.aspp4 = ASPPConv(in_channels, out_channels, atrous_rates[3])
+        self.aspp_pooling = ASPPPooling(in_channels, out_channels)
         self.conv1 = nn.Conv2d(out_channels * (len(atrous_rates) + 1), out_channels, kernel_size=1,
                                weight_init='xavier_uniform')
-        self.bn1 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
+        self.bn1 = nn.BatchNorm2d(out_channels)
         self.relu = nn.ReLU()
         self.conv2 = nn.Conv2d(out_channels, num_classes, kernel_size=1, weight_init='xavier_uniform', has_bias=True)
         self.concat = P.Concat(axis=1)
@@ -160,18 +152,18 @@ class ASPP(nn.Cell):
         x = self.conv1(x)
         x = self.bn1(x)
         x = self.relu(x)
-        if self.phase == 'train':
+        if self.training:
             x = self.drop(x)
         x = self.conv2(x)
         return x
 
 
 class ASPPPooling(nn.Cell):
-    def __init__(self, in_channels, out_channels, use_batch_statistics=True):
+    def __init__(self, in_channels, out_channels):
         super(ASPPPooling, self).__init__()
         self.conv = nn.SequentialCell([
             nn.Conv2d(in_channels, out_channels, kernel_size=1, weight_init='xavier_uniform'),
-            nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics),
+            nn.BatchNorm2d(out_channels),
             nn.ReLU()
         ])
         self.shape = P.Shape()
@@ -185,14 +177,14 @@ class ASPPPooling(nn.Cell):
 
 class ASPPConv(nn.Cell):
-    def __init__(self, in_channels, out_channels, atrous_rate=1, use_batch_statistics=True):
+    def __init__(self, in_channels, out_channels, atrous_rate=1):
         super(ASPPConv, self).__init__()
         if atrous_rate == 1:
             conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, has_bias=False, weight_init='xavier_uniform')
         else:
             conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, pad_mode='pad', padding=atrous_rate,
                              dilation=atrous_rate, weight_init='xavier_uniform')
-        bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
+        bn = nn.BatchNorm2d(out_channels)
         relu = nn.ReLU()
         self.aspp_conv = nn.SequentialCell([conv, bn, relu])
@@ -202,13 +194,10 @@ class ASPPConv(nn.Cell):
 
 class DeepLabV3(nn.Cell):
-    def __init__(self, phase='train', num_classes=21, output_stride=16, freeze_bn=False):
+    def __init__(self, num_classes=21, output_stride=16):
         super(DeepLabV3, self).__init__()
-        use_batch_statistics = not freeze_bn
-        self.resnet = Resnet(Bottleneck, [3, 4, 23, 3], output_stride=output_stride,
-                             use_batch_statistics=use_batch_statistics)
-        self.aspp = ASPP([1, 6, 12, 18], phase, 2048, num_classes,
-                         use_batch_statistics=use_batch_statistics)
+        self.resnet = Resnet(Bottleneck, [3, 4, 23, 3], output_stride=output_stride)
+        self.aspp = ASPP([1, 6, 12, 18], 2048, num_classes)
         self.shape = P.Shape()
 
     def construct(self, x):
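With the `phase` argument gone, train/eval behavior follows the Cell's own mode flag; a sketch of how the mode is now toggled from outside (the behavior notes reflect MindSpore's default BatchNorm/dropout semantics, stated here as assumptions):

```python
# set_train() flips Cell.training, which construct() consults via self.training.
net = DeepLabV3(num_classes=21, output_stride=16)
net.set_train(True)    # ASPP dropout active; BatchNorm uses batch statistics
net.set_train(False)   # dropout branch skipped; BatchNorm uses moving statistics
```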
--- /dev/null
+++ b/src/utils/eval_callback.py
@@ -0,0 +1,85 @@
+import os
+import stat
+
+from mindspore.train.callback import Callback
+from mindspore import log as logger
+from mindspore import save_checkpoint
+from src.utils.eval_utils import net_eval
+
+
+class EvalCallBack(Callback):
+    """
+    Evaluation callback when training.
+
+    Args:
+        eval_function (function): evaluation function.
+        eval_param_dict (dict): evaluation parameters' configure dict.
+        interval (int): run evaluation interval, default is 1.
+        eval_start_epoch (int): evaluation start epoch, default is 1.
+        save_best_ckpt (bool): whether to save the best checkpoint, default is True.
+        ckpt_directory (str): directory to save the best checkpoint, default is "./".
+        best_ckpt_name (str): best checkpoint name, default is `best.ckpt`.
+        metrics_name (str): evaluation metrics name, default is `acc`.
+
+    Returns:
+        None
+
+    Examples:
+        >>> EvalCallBack(eval_function, eval_param_dict)
+    """
+
+    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
+                 ckpt_directory="./", best_ckpt_name="best.ckpt", metrics_name="acc"):
+        super(EvalCallBack, self).__init__()
+        self.eval_param_dict = eval_param_dict
+        self.eval_function = eval_function
+        self.eval_start_epoch = eval_start_epoch
+        if interval < 1:
+            raise ValueError("interval should be >= 1.")
+        self.interval = interval
+        self.save_best_ckpt = save_best_ckpt
+        self.best_res = 0
+        self.best_epoch = 0
+        if not os.path.isdir(ckpt_directory):
+            os.makedirs(ckpt_directory)
+        self.best_ckpt_path = os.path.join(ckpt_directory, best_ckpt_name)
+        self.metrics_name = metrics_name
+
+    def remove_ckpoint_file(self, file_name):
+        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
+        try:
+            os.chmod(file_name, stat.S_IWRITE)
+            os.remove(file_name)
+        except OSError:
+            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
+        except ValueError:
+            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
+
+    def epoch_end(self, run_context):
+        """Callback when epoch ends: run evaluation and keep the best checkpoint."""
+        cb_params = run_context.original_args()
+        cur_epoch = cb_params.cur_epoch_num
+        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
+            res = self.eval_function(self.eval_param_dict)
+            print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
+            if res >= self.best_res:
+                self.best_res = res
+                self.best_epoch = cur_epoch
+                print("update best result: {}".format(res), flush=True)
+                if self.save_best_ckpt:
+                    if os.path.exists(self.best_ckpt_path):
+                        self.remove_ckpoint_file(self.best_ckpt_path)
+                    save_checkpoint(cb_params.train_network, self.best_ckpt_path)
+                    print("update best checkpoint at: {}".format(self.best_ckpt_path), flush=True)
+
+    def end(self, run_context):
+        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
+                                                                                     self.best_res,
+                                                                                     self.best_epoch), flush=True)
+
+
+def apply_eval(eval_param_dict):
+    # The network passed in is already wrapped by BuildEvalNetwork in train.py,
+    # so it is used directly here rather than being wrapped a second time.
+    network = eval_param_dict["net"]
+    network.set_train(False)
+    args = eval_param_dict["args"]
+    args.data_lst = eval_param_dict["dataset"]
+    eval_result = net_eval(args, network)
+    return eval_result
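`EvalCallBack` only assumes that `eval_function(eval_param_dict)` returns a scalar to maximize; a minimal sketch of that contract with a dummy function (all names below are illustrative):

```python
# Sketch only: any callable mapping the param dict to a scalar metric works.
from src.utils.eval_callback import EvalCallBack

def dummy_eval(param_dict):
    # A real eval_function (e.g. apply_eval) would run param_dict["net"]
    # over param_dict["dataset"] and return the metric, e.g. mean IoU.
    return 0.5

eval_cb = EvalCallBack(dummy_eval, {"net": None, "dataset": None, "args": None},
                       interval=1, eval_start_epoch=1, save_best_ckpt=False,
                       metrics_name="mIoU")
# eval_cb is then appended to the callback list handed to Model.train.
```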
--- /dev/null
+++ b/src/utils/eval_utils.py
@@ -0,0 +1,152 @@
+import os
+import cv2
+import numpy as np
+from mindspore import Tensor
+import mindspore.common.dtype as mstype
+import mindspore.nn as nn
+import mindspore.ops as ops
+
+
+def cal_hist(a, b, n):
+    # Confusion matrix via bincount: each (label, prediction) pair is encoded
+    # as n * a + b, so entry [i, j] counts pixels of class i predicted as j.
+    k = (a >= 0) & (a < n)
+    return np.bincount(n * a[k].astype(np.int32) + b[k], minlength=n ** 2).reshape(n, n)
+
+
+def resize_long(img, long_size=513):
+    # scale so the longer image side equals long_size, keeping the aspect ratio
+    h, w, _ = img.shape
+    if h > w:
+        new_h = long_size
+        new_w = int(1.0 * long_size * w / h)
+    else:
+        new_w = long_size
+        new_h = int(1.0 * long_size * h / w)
+    imo = cv2.resize(img, (new_w, new_h))
+    return imo
+
+
+class BuildEvalNetwork(nn.Cell):
+    def __init__(self, network, input_format="NCHW"):
+        super(BuildEvalNetwork, self).__init__()
+        self.network = network
+        self.softmax = nn.Softmax(axis=1)
+        self.transpose = ops.Transpose()
+        self.format = input_format
+
+    def construct(self, input_data):
+        # NHWC inputs are transposed to NCHW before entering the backbone
+        if self.format == "NHWC":
+            input_data = self.transpose(input_data, (0, 3, 1, 2))
+        output = self.network(input_data)
+        output = self.softmax(output)
+        return output
+
+
+def pre_process(args, img_, crop_size=513):
+    # resize
+    img_ = resize_long(img_, crop_size)
+    resize_h, resize_w, _ = img_.shape
+
+    # mean, std
+    image_mean = np.array(args.image_mean)
+    image_std = np.array(args.image_std)
+    img_ = (img_ - image_mean) / image_std
+
+    # pad to crop_size
+    pad_h = crop_size - img_.shape[0]
+    pad_w = crop_size - img_.shape[1]
+    if pad_h > 0 or pad_w > 0:
+        img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
+
+    # hwc to chw
+    img_ = img_.transpose((2, 0, 1))
+    return img_, resize_h, resize_w
+
+
+def eval_batch(args, eval_net, img_lst, crop_size=513, flip=True):
+    result_lst = []
+    batch_size = len(img_lst)
+    batch_img = np.zeros((args.batch_size, 3, crop_size, crop_size), dtype=np.float32)
+    resize_hw = []
+    for l in range(batch_size):
+        img_ = img_lst[l]
+        img_, resize_h, resize_w = pre_process(args, img_, crop_size)
+        batch_img[l] = img_
+        resize_hw.append([resize_h, resize_w])
+
+    batch_img = np.ascontiguousarray(batch_img)
+    net_out = eval_net(Tensor(batch_img, mstype.float32))
+    net_out = net_out.asnumpy()
+    if flip:
+        # Test-time augmentation: add probabilities from the horizontally
+        # flipped input (the output is flipped back before accumulation).
+        batch_img = batch_img[:, :, :, ::-1]
+        net_out_flip = eval_net(Tensor(batch_img, mstype.float32))
+        net_out += net_out_flip.asnumpy()[:, :, :, ::-1]
+
+    for bs in range(batch_size):
+        probs_ = net_out[bs][:, :resize_hw[bs][0], :resize_hw[bs][1]].transpose((1, 2, 0))
+        ori_h, ori_w = img_lst[bs].shape[0], img_lst[bs].shape[1]
+        probs_ = cv2.resize(probs_, (ori_w, ori_h))
+        result_lst.append(probs_)
+
+    return result_lst
+
+
+def eval_batch_scales(args, eval_net, img_lst, scales,
+                      base_crop_size=513, flip=True):
+    # Multi-scale evaluation: sum the per-pixel class probabilities over all
+    # requested crop sizes, then take the argmax.
+    sizes_ = [int((base_crop_size - 1) * sc) + 1 for sc in scales]
+    probs_lst = eval_batch(args, eval_net, img_lst, crop_size=sizes_[0], flip=flip)
+    print(sizes_)
+    for crop_size_ in sizes_[1:]:
+        probs_lst_tmp = eval_batch(args, eval_net, img_lst, crop_size=crop_size_, flip=flip)
+        for pl, _ in enumerate(probs_lst):
+            probs_lst[pl] += probs_lst_tmp[pl]
+
+    result_msk = []
+    for i in probs_lst:
+        result_msk.append(i.argmax(axis=2))
+    return result_msk
+
+
+def net_eval(args, eval_net):
+    # data list
+    with open(args.data_lst) as f:
+        img_lst = f.readlines()
+
+    # evaluate
+    hist = np.zeros((args.num_classes, args.num_classes))
+    batch_img_lst = []
+    batch_msk_lst = []
+    bi = 0
+    image_num = 0
+    for i, line in enumerate(img_lst):
+        img_path, msk_path = line.strip().split(' ')
+        img_path = os.path.join(args.data_root, img_path)
+        msk_path = os.path.join(args.data_root, msk_path)
+        img_ = cv2.imread(img_path)
+        msk_ = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)
+        batch_img_lst.append(img_)
+        batch_msk_lst.append(msk_)
+        bi += 1
+        if bi == args.batch_size:
+            batch_res = eval_batch_scales(args, eval_net, batch_img_lst, scales=args.scales,
+                                          base_crop_size=args.crop_size, flip=args.flip)
+            for mi in range(args.batch_size):
+                hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
+            bi = 0
+            batch_img_lst = []
+            batch_msk_lst = []
+            print('processed {} images'.format(i + 1))
+        image_num = i
+
+    if bi > 0:
+        # handle the final partial batch
+        batch_res = eval_batch_scales(args, eval_net, batch_img_lst, scales=args.scales,
+                                      base_crop_size=args.crop_size, flip=args.flip)
+        for mi in range(bi):
+            hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
+        print('processed {} images'.format(image_num + 1))
+
+    print(hist)
+    iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
+    mean_iou = np.nanmean(iu)
+    print('per-class IoU', iu)
+    print('mean IoU', mean_iou)
+    return mean_iou
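`cal_hist` encodes each (label, prediction) pair as `n * label + pred` and counts the codes with `np.bincount`, which yields the n×n confusion matrix; per-class IoU is then `diag / (row_sum + col_sum - diag)`. A small self-contained check:

```python
import numpy as np

def cal_hist(a, b, n):
    k = (a >= 0) & (a < n)
    return np.bincount(n * a[k].astype(np.int32) + b[k], minlength=n ** 2).reshape(n, n)

gt   = np.array([0, 0, 1, 1, 2, 2])   # ground-truth labels
pred = np.array([0, 1, 1, 1, 2, 0])   # predictions
hist = cal_hist(gt, pred, 3)          # hist[i, j]: pixels of class i predicted as j
iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
print(iu, np.nanmean(iu))             # [0.333..., 0.666..., 0.5] -> mIoU 0.5
```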
--- a/train.py
+++ b/train.py
@@ -31,7 +31,8 @@ from src.data import dataset as data_generator
 from src.loss import loss
 from src.nets import net_factory
 from src.utils import learning_rates
+from src.utils.eval_utils import BuildEvalNetwork
+from src.utils.eval_callback import EvalCallBack, apply_eval
 
 set_seed(1)
 
@@ -72,7 +73,6 @@ def parse_args():
     # model
     parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model')
-    parser.add_argument('--freeze_bn', action='store_true', help='freeze bn')
     parser.add_argument('--ckpt_pre_trained', type=str, default='', help='pretrained model')
     parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
                         help="Filter the last weight parameters, default is False.")
@@ -86,6 +86,22 @@ def parse_args():
     parser.add_argument('--save_steps', type=int, default=3000, help='steps interval for saving')
     parser.add_argument('--keep_checkpoint_max', type=int, default=200, help='max checkpoint for saving')
 
+    # validate
+    parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
+                        help="Run evaluation when training, default is False.")
+    parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
+                        help="Save best checkpoint when run_eval is True, default is True.")
+    parser.add_argument("--eval_start_epoch", type=int, default=200,
+                        help="Evaluation start epoch when run_eval is True, default is 200.")
+    parser.add_argument("--eval_interval", type=int, default=1,
+                        help="Evaluation interval when run_eval is True, default is 1.")
+    parser.add_argument('--ann_file', type=str, default='', help='path to annotation')
+    parser.add_argument('--data_root', type=str, default='', help='root path of val data')
+    parser.add_argument('--val_data', type=str, default='', help='list of val data')
+    parser.add_argument('--scales', type=float, action='append', help='scales of evaluation')
+    parser.add_argument('--flip', action='store_true', help='perform left-right flip')
+    parser.add_argument("--input_format", type=str, choices=["NCHW", "NHWC"], default="NCHW",
+                        help="NCHW or NHWC")
+
     args, _ = parser.parse_known_args()
     return args
@@ -126,9 +142,9 @@ def train():
     # network
     if args.model == 'deeplab_v3_s16':
-        network = net_factory.nets_map[args.model]('train', args.num_classes, 16, args.freeze_bn)
+        network = net_factory.nets_map[args.model](args.num_classes, 16)
     elif args.model == 'deeplab_v3_s8':
-        network = net_factory.nets_map[args.model]('train', args.num_classes, 8, args.freeze_bn)
+        network = net_factory.nets_map[args.model](args.num_classes, 8)
     else:
         raise NotImplementedError('model [{:s}] not recognized'.format(args.model))
@@ -169,6 +185,7 @@ def train():
     # loss scale
     manager_loss_scale = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
     amp_level = "O0" if args.device_target == "CPU" else "O3"
+    train_net.set_train(True)
     model = Model(train_net, optimizer=opt, amp_level=amp_level, loss_scale_manager=manager_loss_scale)
 
     # callback for saving ckpts
@@ -182,6 +199,16 @@ def train():
         ckpoint_cb = ModelCheckpoint(prefix=args.model, directory=args.train_dir, config=config_ck)
         cbs.append(ckpoint_cb)
 
+    if args.run_eval and args.rank == 0:
+        # The network is wrapped for inference once here; apply_eval uses it as-is.
+        network_eval = BuildEvalNetwork(network, args.input_format)
+        eval_dataset = args.val_data
+        save_ckpt_path = args.train_dir
+        eval_param_dict = {"net": network_eval, "dataset": eval_dataset, "args": args}
+        eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args.eval_interval,
+                               eval_start_epoch=args.eval_start_epoch, save_best_ckpt=args.save_best_ckpt,
+                               ckpt_directory=save_ckpt_path, best_ckpt_name="best_map.ckpt",
+                               metrics_name="mIoU")
+        cbs.append(eval_cb)
+
     model.train(args.train_epochs, dataset, callbacks=cbs, dataset_sink_mode=(args.device_target != "CPU"))
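Putting the new flags together, evaluation-while-training is enabled like this (a sketch; paths are placeholders and only the evaluation-related flags are shown alongside the basic training ones):

```bash
python train.py --data_file=/PATH/TO/MINDRECORD_NAME \
                --train_dir=/PATH/TO/CKPT_DIR \
                --model=deeplab_v3_s16 \
                --run_eval=True --eval_start_epoch=200 --eval_interval=1 \
                --save_best_ckpt=True \
                --data_root=/PATH/TO/DATA --val_data=/PATH/TO/DATA_lst.txt \
                --scales=1.0 --flip --input_format=NCHW
```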