From: @jiangzg001 | Reviewed-by: @oacjiewen, @wuxuejian | Signed-off-by: @wuxuejian | Branch: r1.2
@@ -16,14 +16,12 @@
import os
import argparse
import numpy as np
import cv2
from mindspore import Tensor
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.nets import net_factory
from src.utils.eval_utils import BuildEvalNetwork, net_eval

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
                    device_id=int(os.getenv('DEVICE_ID')))
@@ -45,169 +43,24 @@ def parse_args():
    # model
    parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model')
    parser.add_argument('--freeze_bn', action='store_true', default=False, help='freeze bn')
    parser.add_argument('--ckpt_path', type=str, default='', help='model to evaluate')
    args_space, _ = parser.parse_known_args()
    return args_space
    args, _ = parser.parse_known_args()
    return args
def cal_hist(a, b, n):
    k = (a >= 0) & (a < n)
    return np.bincount(n * a[k].astype(np.int32) + b[k], minlength=n ** 2).reshape(n, n)

def resize_long(img, long_size=513):
    h, w, _ = img.shape
    if h > w:
        new_h = long_size
        new_w = int(1.0 * long_size * w / h)
    else:
        new_w = long_size
        new_h = int(1.0 * long_size * h / w)
    imo = cv2.resize(img, (new_w, new_h))
    return imo

class BuildEvalNetwork(nn.Cell):
    def __init__(self, network):
        super(BuildEvalNetwork, self).__init__()
        self.network = network
        self.softmax = nn.Softmax(axis=1)

    def construct(self, input_data):
        output = self.network(input_data)
        output = self.softmax(output)
        return output
def pre_process(args, img_, crop_size=513):
    # resize
    img_ = resize_long(img_, crop_size)
    resize_h, resize_w, _ = img_.shape
    # mean, std
    image_mean = np.array(args.image_mean)
    image_std = np.array(args.image_std)
    img_ = (img_ - image_mean) / image_std
    # pad to crop_size
    pad_h = crop_size - img_.shape[0]
    pad_w = crop_size - img_.shape[1]
    if pad_h > 0 or pad_w > 0:
        img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
    # hwc to chw
    img_ = img_.transpose((2, 0, 1))
    return img_, resize_h, resize_w
def eval_batch(args, eval_net, img_lst, crop_size=513, flip=True):
    result_lst = []
    batch_size = len(img_lst)
    batch_img = np.zeros((args.batch_size, 3, crop_size, crop_size), dtype=np.float32)
    resize_hw = []
    for l in range(batch_size):
        img_ = img_lst[l]
        img_, resize_h, resize_w = pre_process(args, img_, crop_size)
        batch_img[l] = img_
        resize_hw.append([resize_h, resize_w])

    batch_img = np.ascontiguousarray(batch_img)
    net_out = eval_net(Tensor(batch_img, mstype.float32))
    net_out = net_out.asnumpy()
    if flip:
        batch_img = batch_img[:, :, :, ::-1]
        net_out_flip = eval_net(Tensor(batch_img, mstype.float32))
        net_out += net_out_flip.asnumpy()[:, :, :, ::-1]

    for bs in range(batch_size):
        probs_ = net_out[bs][:, :resize_hw[bs][0], :resize_hw[bs][1]].transpose((1, 2, 0))
        ori_h, ori_w = img_lst[bs].shape[0], img_lst[bs].shape[1]
        probs_ = cv2.resize(probs_, (ori_w, ori_h))
        result_lst.append(probs_)
    return result_lst

def eval_batch_scales(args, eval_net, img_lst, scales,
                      base_crop_size=513, flip=True):
    sizes_ = [int((base_crop_size - 1) * sc) + 1 for sc in scales]
    probs_lst = eval_batch(args, eval_net, img_lst, crop_size=sizes_[0], flip=flip)
    print(sizes_)
    for crop_size_ in sizes_[1:]:
        probs_lst_tmp = eval_batch(args, eval_net, img_lst, crop_size=crop_size_, flip=flip)
        for pl, _ in enumerate(probs_lst):
            probs_lst[pl] += probs_lst_tmp[pl]

    result_msk = []
    for i in probs_lst:
        result_msk.append(i.argmax(axis=2))
    return result_msk
def net_eval():
if __name__ == '__main__':
    args = parse_args()

    # data list
    with open(args.data_lst) as f:
        img_lst = f.readlines()

    # network
    if args.model == 'deeplab_v3_s16':
        network = net_factory.nets_map[args.model]('eval', args.num_classes, 16, args.freeze_bn)
        network = net_factory.nets_map[args.model](args.num_classes, 16)
    elif args.model == 'deeplab_v3_s8':
        network = net_factory.nets_map[args.model]('eval', args.num_classes, 8, args.freeze_bn)
        network = net_factory.nets_map[args.model](args.num_classes, 8)
    else:
        raise NotImplementedError('model [{:s}] not recognized'.format(args.model))

    eval_net = BuildEvalNetwork(network)
    eval_net = BuildEvalNetwork(network, args.input_format)

    # load model
    param_dict = load_checkpoint(args.ckpt_path)
    load_param_into_net(eval_net, param_dict)
    eval_net.set_train(False)

    # evaluate
    hist = np.zeros((args.num_classes, args.num_classes))
    batch_img_lst = []
    batch_msk_lst = []
    bi = 0
    image_num = 0
    for i, line in enumerate(img_lst):
        img_path, msk_path = line.strip().split(' ')
        img_path = os.path.join(args.data_root, img_path)
        msk_path = os.path.join(args.data_root, msk_path)
        img_ = cv2.imread(img_path)
        msk_ = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)
        batch_img_lst.append(img_)
        batch_msk_lst.append(msk_)
        bi += 1
        if bi == args.batch_size:
            batch_res = eval_batch_scales(args, eval_net, batch_img_lst, scales=args.scales,
                                          base_crop_size=args.crop_size, flip=args.flip)
            for mi in range(args.batch_size):
                hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
            bi = 0
            batch_img_lst = []
            batch_msk_lst = []
            print('processed {} images'.format(i + 1))
        image_num = i

    if bi > 0:
        batch_res = eval_batch_scales(args, eval_net, batch_img_lst, scales=args.scales,
                                      base_crop_size=args.crop_size, flip=args.flip)
        for mi in range(bi):
            hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
        print('processed {} images'.format(image_num + 1))

    print(hist)
    iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
    print('per-class IoU', iu)
    print('mean IoU', np.nanmean(iu))

if __name__ == '__main__':
    net_eval()
    net_eval(args, eval_net)
@@ -27,7 +27,10 @@ if [ -d ${train_path} ]; then
fi
mkdir -p ${train_path}
mkdir ${train_path}/ckpt
# If evaluation during training is enabled, `data_root` and `val_data` must be set;
# otherwise they can be left as empty strings ("").
for((i=0;i<=$RANK_SIZE-1;i++));
do
    export RANK_ID=${i}
@@ -37,6 +40,9 @@ do
    cd ${train_path}/device${DEVICE_ID} || exit
    python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
        --data_file=/PATH/TO/MINDRECORD_NAME \
        --data_root=/PATH/TO/DATA \
        --val_data=/PATH/TO/DATA_lst.txt \
        --scales=1.0 \
        --train_epochs=300 \
        --batch_size=32 \
        --crop_size=513 \
@@ -27,7 +27,10 @@ if [ -d ${train_path} ]; then
fi
mkdir -p ${train_path}
mkdir ${train_path}/ckpt
# If evaluation during training is enabled, `data_root` and `val_data` must be set;
# otherwise they can be left as empty strings ("").
for((i=0;i<=$RANK_SIZE-1;i++));
do
    export RANK_ID=${i}
@@ -37,6 +40,9 @@ do
    cd ${train_path}/device${DEVICE_ID} || exit
    python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
        --data_file=/PATH/TO/MINDRECORD_NAME \
        --data_root=/PATH/TO/DATA \
        --val_data=/PATH/TO/DATA_lst.txt \
        --scales=1.0 \
        --train_epochs=800 \
        --batch_size=16 \
        --crop_size=513 \
@@ -51,5 +57,6 @@ do
        --ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
        --is_distributed \
        --save_steps=820 \
        --scales=1.0 \
        --keep_checkpoint_max=200 >log 2>&1 &
done
@@ -27,7 +27,10 @@ if [ -d ${train_path} ]; then
fi
mkdir -p ${train_path}
mkdir ${train_path}/ckpt
# If evaluation during training is enabled, `data_root` and `val_data` must be set;
# otherwise they can be left as empty strings ("").
for((i=0;i<=$RANK_SIZE-1;i++));
do
    export RANK_ID=${i}
@@ -37,6 +40,9 @@ do
    cd ${train_path}/device${DEVICE_ID} || exit
    python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
        --data_file=/PATH/TO/MINDRECORD_NAME \
        --data_root=/PATH/TO/DATA \
        --val_data=/PATH/TO/DATA_lst.txt \
        --scales=1.0 \
        --train_epochs=300 \
        --batch_size=16 \
        --crop_size=513 \
@@ -32,6 +32,5 @@ python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
    --num_classes=21 \
    --model=deeplab_v3_s16 \
    --scales=1.0 \
    --freeze_bn \
    --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
@@ -32,6 +32,5 @@ python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
    --num_classes=21 \
    --model=deeplab_v3_s8 \
    --scales=1.0 \
    --freeze_bn \
    --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
@@ -36,6 +36,5 @@ python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
    --scales=1.0 \
    --scales=1.25 \
    --scales=1.75 \
    --freeze_bn \
    --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
@@ -37,6 +37,5 @@ python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
    --scales=1.25 \
    --scales=1.75 \
    --flip \
    --freeze_bn \
    --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
@@ -26,9 +26,15 @@ mkdir -p ${train_path}
mkdir ${train_path}/device${DEVICE_ID}
mkdir ${train_path}/ckpt
cd ${train_path}/device${DEVICE_ID} || exit
# If evaluation during training is enabled, `data_root` and `val_data` must be set;
# otherwise they can be left as empty strings ("").
python ${train_code_path}/train.py --data_file=/PATH/TO/MINDRECORD_NAME \
    --train_dir=${train_path}/ckpt \
    --data_root=/PATH/TO/DATA \
    --val_data=/PATH/TO/DATA_lst.txt \
    --scales=1.0 \
    --train_epochs=200 \
    --batch_size=32 \
    --crop_size=513 \
@@ -27,47 +27,41 @@ def conv3x3(in_planes, out_planes, stride=1, dilation=1, padding=1):
class Resnet(nn.Cell):
    def __init__(self, block, block_num, output_stride, use_batch_statistics=True):
    def __init__(self, block, block_num, output_stride):
        super(Resnet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, pad_mode='pad', padding=3,
                               weight_init='xavier_uniform')
        self.bn1 = nn.BatchNorm2d(self.inplanes, use_batch_statistics=use_batch_statistics)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        self.layer1 = self._make_layer(block, 64, block_num[0], use_batch_statistics=use_batch_statistics)
        self.layer2 = self._make_layer(block, 128, block_num[1], stride=2, use_batch_statistics=use_batch_statistics)
        self.layer1 = self._make_layer(block, 64, block_num[0])
        self.layer2 = self._make_layer(block, 128, block_num[1], stride=2)
        if output_stride == 16:
            self.layer3 = self._make_layer(block, 256, block_num[2], stride=2,
                                           use_batch_statistics=use_batch_statistics)
            self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=2, grids=[1, 2, 4],
                                           use_batch_statistics=use_batch_statistics)
            self.layer3 = self._make_layer(block, 256, block_num[2], stride=2)
            self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=2, grids=[1, 2, 4])
        elif output_stride == 8:
            self.layer3 = self._make_layer(block, 256, block_num[2], stride=1, base_dilation=2,
                                           use_batch_statistics=use_batch_statistics)
            self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=4, grids=[1, 2, 4],
                                           use_batch_statistics=use_batch_statistics)
            self.layer3 = self._make_layer(block, 256, block_num[2], stride=1, base_dilation=2)
            self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=4, grids=[1, 2, 4])

    def _make_layer(self, block, planes, blocks, stride=1, base_dilation=1, grids=None, use_batch_statistics=True):
    def _make_layer(self, block, planes, blocks, stride=1, base_dilation=1, grids=None):
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.SequentialCell([
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, use_batch_statistics=use_batch_statistics)
                nn.BatchNorm2d(planes * block.expansion)
            ])
        if grids is None:
            grids = [1] * blocks
        layers = [
            block(self.inplanes, planes, stride, downsample, dilation=base_dilation * grids[0],
                  use_batch_statistics=use_batch_statistics)
            block(self.inplanes, planes, stride, downsample, dilation=base_dilation * grids[0])
        ]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(self.inplanes, planes, dilation=base_dilation * grids[i],
                      use_batch_statistics=use_batch_statistics))
                block(self.inplanes, planes, dilation=base_dilation * grids[i]))
        return nn.SequentialCell(layers)
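For reference, the multi-grid scheme above multiplies a base dilation into each bottleneck block of the last stage; reading the values straight off the constructor calls, the per-block atrous rates are:

```python
# Dilation per bottleneck block in layer4, following _make_layer above.
base_dilation, grids = 2, [1, 2, 4]        # output_stride == 16 branch
print([base_dilation * g for g in grids])  # [2, 4, 8]
base_dilation = 4                          # output_stride == 8 branch
print([base_dilation * g for g in grids])  # [4, 8, 16]
```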
@@ -88,16 +82,16 @@ class Resnet(nn.Cell):
class Bottleneck(nn.Cell):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, use_batch_statistics=True):
    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes, use_batch_statistics=use_batch_statistics)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride, dilation, dilation)
        self.bn2 = nn.BatchNorm2d(planes, use_batch_statistics=use_batch_statistics)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion, use_batch_statistics=use_batch_statistics)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU()
        self.downsample = downsample

@@ -127,19 +121,17 @@ class Bottleneck(nn.Cell):
class ASPP(nn.Cell):
    def __init__(self, atrous_rates, phase='train', in_channels=2048, num_classes=21,
                 use_batch_statistics=True):
    def __init__(self, atrous_rates, in_channels=2048, num_classes=21):
        super(ASPP, self).__init__()
        self.phase = phase
        out_channels = 256
        self.aspp1 = ASPPConv(in_channels, out_channels, atrous_rates[0], use_batch_statistics=use_batch_statistics)
        self.aspp2 = ASPPConv(in_channels, out_channels, atrous_rates[1], use_batch_statistics=use_batch_statistics)
        self.aspp3 = ASPPConv(in_channels, out_channels, atrous_rates[2], use_batch_statistics=use_batch_statistics)
        self.aspp4 = ASPPConv(in_channels, out_channels, atrous_rates[3], use_batch_statistics=use_batch_statistics)
        self.aspp_pooling = ASPPPooling(in_channels, out_channels, use_batch_statistics=use_batch_statistics)
        self.aspp1 = ASPPConv(in_channels, out_channels, atrous_rates[0])
        self.aspp2 = ASPPConv(in_channels, out_channels, atrous_rates[1])
        self.aspp3 = ASPPConv(in_channels, out_channels, atrous_rates[2])
        self.aspp4 = ASPPConv(in_channels, out_channels, atrous_rates[3])
        self.aspp_pooling = ASPPPooling(in_channels, out_channels)
        self.conv1 = nn.Conv2d(out_channels * (len(atrous_rates) + 1), out_channels, kernel_size=1,
                               weight_init='xavier_uniform')
        self.bn1 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, num_classes, kernel_size=1, weight_init='xavier_uniform', has_bias=True)
        self.concat = P.Concat(axis=1)

@@ -160,18 +152,18 @@ class ASPP(nn.Cell):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if self.phase == 'train':
        if self.training:
            x = self.drop(x)
        x = self.conv2(x)
        return x

class ASPPPooling(nn.Cell):
    def __init__(self, in_channels, out_channels, use_batch_statistics=True):
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__()
        self.conv = nn.SequentialCell([
            nn.Conv2d(in_channels, out_channels, kernel_size=1, weight_init='xavier_uniform'),
            nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ])
        self.shape = P.Shape()

@@ -185,14 +177,14 @@ class ASPPPooling(nn.Cell):
class ASPPConv(nn.Cell):
    def __init__(self, in_channels, out_channels, atrous_rate=1, use_batch_statistics=True):
    def __init__(self, in_channels, out_channels, atrous_rate=1):
        super(ASPPConv, self).__init__()
        if atrous_rate == 1:
            conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, has_bias=False, weight_init='xavier_uniform')
        else:
            conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, pad_mode='pad', padding=atrous_rate,
                             dilation=atrous_rate, weight_init='xavier_uniform')
        bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
        bn = nn.BatchNorm2d(out_channels)
        relu = nn.ReLU()
        self.aspp_conv = nn.SequentialCell([conv, bn, relu])

@@ -202,13 +194,10 @@ class ASPPConv(nn.Cell):
class DeepLabV3(nn.Cell):
    def __init__(self, phase='train', num_classes=21, output_stride=16, freeze_bn=False):
    def __init__(self, num_classes=21, output_stride=16):
        super(DeepLabV3, self).__init__()
        use_batch_statistics = not freeze_bn
        self.resnet = Resnet(Bottleneck, [3, 4, 23, 3], output_stride=output_stride,
                             use_batch_statistics=use_batch_statistics)
        self.aspp = ASPP([1, 6, 12, 18], phase, 2048, num_classes,
                         use_batch_statistics=use_batch_statistics)
        self.resnet = Resnet(Bottleneck, [3, 4, 23, 3], output_stride=output_stride)
        self.aspp = ASPP([1, 6, 12, 18], 2048, num_classes)
        self.shape = P.Shape()

    def construct(self, x):
@@ -0,0 +1,85 @@
import os
import stat

from mindspore.train.callback import Callback
from mindspore import log as logger
from mindspore import save_checkpoint

from src.utils.eval_utils import net_eval
class EvalCallBack(Callback):
    """
    Evaluation callback while training.

    Args:
        eval_function (function): evaluation function.
        eval_param_dict (dict): evaluation parameters' configuration dict.
        interval (int): run evaluation interval, default is 1.
        eval_start_epoch (int): evaluation start epoch, default is 1.
        save_best_ckpt (bool): whether to save the best checkpoint, default is True.
        best_ckpt_name (str): best checkpoint file name, default is `best.ckpt`.
        metrics_name (str): evaluation metrics name, default is `acc`.

    Returns:
        None

    Examples:
        >>> EvalCallBack(eval_function, eval_param_dict)
    """
    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
                 ckpt_directory="./", best_ckpt_name="best.ckpt", metrics_name="acc"):
        super(EvalCallBack, self).__init__()
        self.eval_param_dict = eval_param_dict
        self.eval_function = eval_function
        self.eval_start_epoch = eval_start_epoch
        if interval < 1:
            raise ValueError("interval must be >= 1.")
        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        self.best_res = 0
        self.best_epoch = 0
        if not os.path.isdir(ckpt_directory):
            os.makedirs(ckpt_directory)
        self.best_ckpt_path = os.path.join(ckpt_directory, best_ckpt_name)
        self.metrics_name = metrics_name
    def remove_ckpoint_file(self, file_name):
        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
        try:
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
    def epoch_end(self, run_context):
        """Callback when an epoch ends."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
            res = self.eval_function(self.eval_param_dict)
            print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
            if res >= self.best_res:
                self.best_res = res
                self.best_epoch = cur_epoch
                print("update best result: {}".format(res), flush=True)
                if self.save_best_ckpt:
                    if os.path.exists(self.best_ckpt_path):
                        self.remove_ckpoint_file(self.best_ckpt_path)
                    save_checkpoint(cb_params.train_network, self.best_ckpt_path)
                    print("update best checkpoint at: {}".format(self.best_ckpt_path), flush=True)

    def end(self, run_context):
        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
                                                                                     self.best_res,
                                                                                     self.best_epoch), flush=True)
def apply_eval(eval_param_dict):
    # The network passed in is expected to be wrapped with BuildEvalNetwork already
    # (train.py does this); wrapping it again here would apply softmax twice.
    network = eval_param_dict["net"]
    network.set_train(False)
    args = eval_param_dict["args"]
    args.data_lst = eval_param_dict["dataset"]
    eval_result = net_eval(args, network)
    return eval_result
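For context, a hedged sketch of how this callback plugs into training. Here `network`, `model`, `dataset`, `args`, and the paths are placeholders; the real wiring appears in train.py further below:

```python
# Hypothetical wiring, mirroring what train.py does.
from src.utils.eval_utils import BuildEvalNetwork

eval_net = BuildEvalNetwork(network)  # wrap the backbone once; apply_eval does not re-wrap
eval_param_dict = {"net": eval_net, "dataset": "/path/to/voc_val_lst.txt", "args": args}
eval_cb = EvalCallBack(apply_eval, eval_param_dict,
                       interval=1, eval_start_epoch=200,
                       save_best_ckpt=True, ckpt_directory="./ckpt",
                       best_ckpt_name="best.ckpt", metrics_name="mIoU")
model.train(args.train_epochs, dataset, callbacks=[eval_cb])
```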
@@ -0,0 +1,152 @@
import os

import cv2
import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops as ops

def cal_hist(a, b, n):
    k = (a >= 0) & (a < n)
    return np.bincount(n * a[k].astype(np.int32) + b[k], minlength=n ** 2).reshape(n, n)
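cal_hist flattens each (ground truth, prediction) pixel pair into a single index n * a + b, counts with bincount, and reshapes into an n x n confusion matrix; labels outside [0, n) are masked out by k. A toy three-class example (values illustrative), using cal_hist as defined above:

```python
import numpy as np

gt = np.array([0, 0, 1, 2, 2])    # ground-truth labels
pred = np.array([0, 1, 1, 2, 0])  # predicted labels
hist = cal_hist(gt, pred, 3)
# hist[a, b] counts pixels of class a predicted as b:
# [[1 1 0]
#  [0 1 0]
#  [1 0 1]]
iou = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
# per-class IoU: [0.3333, 0.5, 0.5] -- the same formula net_eval applies below
```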
def resize_long(img, long_size=513):
    h, w, _ = img.shape
    if h > w:
        new_h = long_size
        new_w = int(1.0 * long_size * w / h)
    else:
        new_w = long_size
        new_h = int(1.0 * long_size * h / w)
    imo = cv2.resize(img, (new_w, new_h))
    return imo

class BuildEvalNetwork(nn.Cell):
    def __init__(self, network, input_format="NCHW"):
        super(BuildEvalNetwork, self).__init__()
        self.network = network
        self.softmax = nn.Softmax(axis=1)
        self.transpose = ops.Transpose()
        self.format = input_format

    def construct(self, input_data):
        if self.format == "NHWC":
            input_data = self.transpose(input_data, (0, 3, 1, 2))
        output = self.network(input_data)
        output = self.softmax(output)
        return output
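When input_format is "NHWC", the wrapper permutes inputs to NCHW before the backbone runs, then takes softmax over the channel axis. A minimal sketch with a stand-in cell (the Identity network is purely illustrative, not part of the project):

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.common.dtype as mstype

class Identity(nn.Cell):  # stand-in for DeepLabV3, illustrative only
    def __init__(self):
        super(Identity, self).__init__()

    def construct(self, x):
        return x

net = BuildEvalNetwork(Identity(), input_format="NHWC")
x = Tensor(np.zeros((1, 65, 65, 3), np.float32), mstype.float32)  # NHWC input
out = net(x)  # permuted to (1, 3, 65, 65), then softmax over axis 1
```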
def pre_process(args, img_, crop_size=513):
    # resize
    img_ = resize_long(img_, crop_size)
    resize_h, resize_w, _ = img_.shape
    # mean, std
    image_mean = np.array(args.image_mean)
    image_std = np.array(args.image_std)
    img_ = (img_ - image_mean) / image_std
    # pad to crop_size
    pad_h = crop_size - img_.shape[0]
    pad_w = crop_size - img_.shape[1]
    if pad_h > 0 or pad_w > 0:
        img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
    # hwc to chw
    img_ = img_.transpose((2, 0, 1))
    return img_, resize_h, resize_w
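pre_process therefore always yields a crop_size x crop_size CHW array: resize_long scales the long side to crop_size, the short side is zero-padded, and resize_h/resize_w record the valid (unpadded) region that eval_batch crops back out later. A shape-only sketch (the mean/std values here are placeholders, not the project's configured statistics):

```python
import numpy as np
from types import SimpleNamespace

args = SimpleNamespace(image_mean=[128.0, 128.0, 128.0],  # placeholder stats
                       image_std=[64.0, 64.0, 64.0])
img = np.zeros((375, 500, 3), np.uint8)                   # HWC input
out, rh, rw = pre_process(args, img, crop_size=513)
print(out.shape, rh, rw)                                  # (3, 513, 513) 384 513
```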
def eval_batch(args, eval_net, img_lst, crop_size=513, flip=True):
    result_lst = []
    batch_size = len(img_lst)
    batch_img = np.zeros((args.batch_size, 3, crop_size, crop_size), dtype=np.float32)
    resize_hw = []
    for l in range(batch_size):
        img_ = img_lst[l]
        img_, resize_h, resize_w = pre_process(args, img_, crop_size)
        batch_img[l] = img_
        resize_hw.append([resize_h, resize_w])

    batch_img = np.ascontiguousarray(batch_img)
    net_out = eval_net(Tensor(batch_img, mstype.float32))
    net_out = net_out.asnumpy()
    if flip:
        batch_img = batch_img[:, :, :, ::-1]
        net_out_flip = eval_net(Tensor(batch_img, mstype.float32))
        net_out += net_out_flip.asnumpy()[:, :, :, ::-1]

    for bs in range(batch_size):
        probs_ = net_out[bs][:, :resize_hw[bs][0], :resize_hw[bs][1]].transpose((1, 2, 0))
        ori_h, ori_w = img_lst[bs].shape[0], img_lst[bs].shape[1]
        probs_ = cv2.resize(probs_, (ori_w, ori_h))
        result_lst.append(probs_)
    return result_lst
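The flip branch is plain horizontal test-time augmentation: the batch is mirrored along the width axis, run through the network, and the mirrored output is flipped back before being summed, so both passes stay pixel-aligned. A small numpy check of that alignment invariant:

```python
import numpy as np

x = np.arange(2 * 3 * 4 * 4, dtype=np.float32).reshape(2, 3, 4, 4)  # toy NCHW batch
x_flip = x[:, :, :, ::-1]                        # mirror along the width axis
assert np.array_equal(x_flip[:, :, :, ::-1], x)  # un-mirroring restores alignment
fused = x + x_flip[:, :, :, ::-1]                # what eval_batch accumulates
```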
def eval_batch_scales(args, eval_net, img_lst, scales,
                      base_crop_size=513, flip=True):
    sizes_ = [int((base_crop_size - 1) * sc) + 1 for sc in scales]
    probs_lst = eval_batch(args, eval_net, img_lst, crop_size=sizes_[0], flip=flip)
    print(sizes_)
    for crop_size_ in sizes_[1:]:
        probs_lst_tmp = eval_batch(args, eval_net, img_lst, crop_size=crop_size_, flip=flip)
        for pl, _ in enumerate(probs_lst):
            probs_lst[pl] += probs_lst_tmp[pl]

    result_msk = []
    for i in probs_lst:
        result_msk.append(i.argmax(axis=2))
    return result_msk
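The multi-scale crop sizes follow size = (base - 1) * scale + 1, which keeps every size of the form 16k + 1 (513 = 16 * 32 + 1) so feature maps stay aligned across scales. For the scale sets used in the eval scripts:

```python
base_crop_size, scales = 513, [1.0, 1.25, 1.75]  # values from the eval scripts
sizes = [int((base_crop_size - 1) * sc) + 1 for sc in scales]
print(sizes)  # [513, 641, 897] -- each is 16 * k + 1
```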
def net_eval(args, eval_net):
    # data list
    with open(args.data_lst) as f:
        img_lst = f.readlines()

    # evaluate
    hist = np.zeros((args.num_classes, args.num_classes))
    batch_img_lst = []
    batch_msk_lst = []
    bi = 0
    image_num = 0
    for i, line in enumerate(img_lst):
        img_path, msk_path = line.strip().split(' ')
        img_path = os.path.join(args.data_root, img_path)
        msk_path = os.path.join(args.data_root, msk_path)
        img_ = cv2.imread(img_path)
        msk_ = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)
        batch_img_lst.append(img_)
        batch_msk_lst.append(msk_)
        bi += 1
        if bi == args.batch_size:
            batch_res = eval_batch_scales(args, eval_net, batch_img_lst, scales=args.scales,
                                          base_crop_size=args.crop_size, flip=args.flip)
            for mi in range(args.batch_size):
                hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
            bi = 0
            batch_img_lst = []
            batch_msk_lst = []
            print('processed {} images'.format(i + 1))
        image_num = i

    if bi > 0:
        batch_res = eval_batch_scales(args, eval_net, batch_img_lst, scales=args.scales,
                                      base_crop_size=args.crop_size, flip=args.flip)
        for mi in range(bi):
            hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
        print('processed {} images'.format(image_num + 1))

    print(hist)
    iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
    mean_iou = np.nanmean(iu)
    print('per-class IoU', iu)
    print('mean IoU', mean_iou)
    return mean_iou
@@ -31,7 +31,8 @@ from src.data import dataset as data_generator
from src.loss import loss
from src.nets import net_factory
from src.utils import learning_rates
from src.utils.eval_utils import BuildEvalNetwork
from src.utils.eval_callback import EvalCallBack, apply_eval

set_seed(1)

@@ -72,7 +73,6 @@ def parse_args():
    # model
    parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model')
    parser.add_argument('--freeze_bn', action='store_true', help='freeze bn')
    parser.add_argument('--ckpt_pre_trained', type=str, default='', help='pretrained model')
    parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
                        help="Filter the last weight parameters, default is False.")
@@ -86,6 +86,22 @@ def parse_args():
    parser.add_argument('--save_steps', type=int, default=3000, help='steps interval for saving')
    parser.add_argument('--keep_checkpoint_max', type=int, default=200, help='max checkpoints to keep')

    # validate
    parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
                        help="Run evaluation when training, default is False.")
    parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
                        help="Save best checkpoint when run_eval is True, default is True.")
    parser.add_argument("--eval_start_epoch", type=int, default=200,
                        help="Evaluation start epoch when run_eval is True, default is 200.")
    parser.add_argument("--eval_interval", type=int, default=1,
                        help="Evaluation interval when run_eval is True, default is 1.")
    parser.add_argument('--ann_file', type=str, default='', help='path to annotation')
    parser.add_argument('--data_root', type=str, default='', help='root path of val data')
    parser.add_argument('--val_data', type=str, default='', help='list of val data')
    parser.add_argument('--scales', type=float, action='append', help='scales of evaluation')
    parser.add_argument('--flip', action='store_true', help='perform left-right flip')
    parser.add_argument("--input_format", type=str, choices=["NCHW", "NHWC"], default="NCHW",
                        help="NCHW or NHWC")

    args, _ = parser.parse_known_args()
    return args
@@ -126,9 +142,9 @@ def train():
    # network
    if args.model == 'deeplab_v3_s16':
        network = net_factory.nets_map[args.model]('train', args.num_classes, 16, args.freeze_bn)
        network = net_factory.nets_map[args.model](args.num_classes, 16)
    elif args.model == 'deeplab_v3_s8':
        network = net_factory.nets_map[args.model]('train', args.num_classes, 8, args.freeze_bn)
        network = net_factory.nets_map[args.model](args.num_classes, 8)
    else:
        raise NotImplementedError('model [{:s}] not recognized'.format(args.model))

@@ -169,6 +185,7 @@ def train():
    # loss scale
    manager_loss_scale = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
    amp_level = "O0" if args.device_target == "CPU" else "O3"
    train_net.set_train(True)
    model = Model(train_net, optimizer=opt, amp_level=amp_level, loss_scale_manager=manager_loss_scale)

    # callback for saving ckpts
@@ -182,6 +199,16 @@ def train():
        ckpoint_cb = ModelCheckpoint(prefix=args.model, directory=args.train_dir, config=config_ck)
        cbs.append(ckpoint_cb)

    if args.run_eval and args.rank == 0:
        network_eval = BuildEvalNetwork(network, args.input_format)
        eval_dataset = args.val_data
        save_ckpt_path = args.train_dir
        eval_param_dict = {"net": network_eval, "dataset": eval_dataset, "args": args}
        eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args.eval_interval,
                               eval_start_epoch=args.eval_start_epoch, save_best_ckpt=True,
                               ckpt_directory=save_ckpt_path, best_ckpt_name="best_map.ckpt",
                               metrics_name="mIoU")
        cbs.append(eval_cb)

    model.train(args.train_epochs, dataset, callbacks=cbs, dataset_sink_mode=(args.device_target != "CPU"))