diff --git a/model_zoo/official/cv/deeplabv3/eval.py b/model_zoo/official/cv/deeplabv3/eval.py index e36cbce4a5..b41e133a2d 100644 --- a/model_zoo/official/cv/deeplabv3/eval.py +++ b/model_zoo/official/cv/deeplabv3/eval.py @@ -16,14 +16,12 @@ import os import argparse -import numpy as np -import cv2 -from mindspore import Tensor -import mindspore.common.dtype as mstype -import mindspore.nn as nn + from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net from src.nets import net_factory +from src.utils.eval_utils import BuildEvalNetwork, net_eval + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=int(os.getenv('DEVICE_ID'))) @@ -45,169 +43,24 @@ def parse_args(): # model parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model') - parser.add_argument('--freeze_bn', action='store_true', default=False, help='freeze bn') parser.add_argument('--ckpt_path', type=str, default='', help='model to evaluate') + args_space, _ = parser.parse_known_args() + return args_space - args, _ = parser.parse_known_args() - return args - - -def cal_hist(a, b, n): - k = (a >= 0) & (a < n) - return np.bincount(n * a[k].astype(np.int32) + b[k], minlength=n ** 2).reshape(n, n) - - -def resize_long(img, long_size=513): - h, w, _ = img.shape - if h > w: - new_h = long_size - new_w = int(1.0 * long_size * w / h) - else: - new_w = long_size - new_h = int(1.0 * long_size * h / w) - imo = cv2.resize(img, (new_w, new_h)) - return imo - - -class BuildEvalNetwork(nn.Cell): - def __init__(self, network): - super(BuildEvalNetwork, self).__init__() - self.network = network - self.softmax = nn.Softmax(axis=1) - - def construct(self, input_data): - output = self.network(input_data) - output = self.softmax(output) - return output - - -def pre_process(args, img_, crop_size=513): - # resize - img_ = resize_long(img_, crop_size) - resize_h, resize_w, _ = img_.shape - - # mean, std - image_mean = np.array(args.image_mean) - image_std = np.array(args.image_std) - img_ = (img_ - image_mean) / image_std - - # pad to crop_size - pad_h = crop_size - img_.shape[0] - pad_w = crop_size - img_.shape[1] - if pad_h > 0 or pad_w > 0: - img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0) - - # hwc to chw - img_ = img_.transpose((2, 0, 1)) - return img_, resize_h, resize_w - - -def eval_batch(args, eval_net, img_lst, crop_size=513, flip=True): - result_lst = [] - batch_size = len(img_lst) - batch_img = np.zeros((args.batch_size, 3, crop_size, crop_size), dtype=np.float32) - resize_hw = [] - for l in range(batch_size): - img_ = img_lst[l] - img_, resize_h, resize_w = pre_process(args, img_, crop_size) - batch_img[l] = img_ - resize_hw.append([resize_h, resize_w]) - - batch_img = np.ascontiguousarray(batch_img) - net_out = eval_net(Tensor(batch_img, mstype.float32)) - net_out = net_out.asnumpy() - if flip: - batch_img = batch_img[:, :, :, ::-1] - net_out_flip = eval_net(Tensor(batch_img, mstype.float32)) - net_out += net_out_flip.asnumpy()[:, :, :, ::-1] - - for bs in range(batch_size): - probs_ = net_out[bs][:, :resize_hw[bs][0], :resize_hw[bs][1]].transpose((1, 2, 0)) - ori_h, ori_w = img_lst[bs].shape[0], img_lst[bs].shape[1] - probs_ = cv2.resize(probs_, (ori_w, ori_h)) - result_lst.append(probs_) - - return result_lst - - -def eval_batch_scales(args, eval_net, img_lst, scales, - base_crop_size=513, flip=True): - sizes_ = [int((base_crop_size - 1) * sc) + 1 for sc in scales] - probs_lst = 
eval_batch(args, eval_net, img_lst, crop_size=sizes_[0], flip=flip)
-    print(sizes_)
-    for crop_size_ in sizes_[1:]:
-        probs_lst_tmp = eval_batch(args, eval_net, img_lst, crop_size=crop_size_, flip=flip)
-        for pl, _ in enumerate(probs_lst):
-            probs_lst[pl] += probs_lst_tmp[pl]
-
-    result_msk = []
-    for i in probs_lst:
-        result_msk.append(i.argmax(axis=2))
-    return result_msk
-
-
-def net_eval():
+if __name__ == '__main__':
     args = parse_args()
-
-    # data list
-    with open(args.data_lst) as f:
-        img_lst = f.readlines()
-
     # network
     if args.model == 'deeplab_v3_s16':
-        network = net_factory.nets_map[args.model]('eval', args.num_classes, 16, args.freeze_bn)
+        network = net_factory.nets_map[args.model](args.num_classes, 16)
     elif args.model == 'deeplab_v3_s8':
-        network = net_factory.nets_map[args.model]('eval', args.num_classes, 8, args.freeze_bn)
+        network = net_factory.nets_map[args.model](args.num_classes, 8)
     else:
         raise NotImplementedError('model [{:s}] not recognized'.format(args.model))
-
-    eval_net = BuildEvalNetwork(network)
-
+    eval_net = BuildEvalNetwork(network, args.input_format)
     # load model
     param_dict = load_checkpoint(args.ckpt_path)
     load_param_into_net(eval_net, param_dict)
     eval_net.set_train(False)

-    # evaluate
-    hist = np.zeros((args.num_classes, args.num_classes))
-    batch_img_lst = []
-    batch_msk_lst = []
-    bi = 0
-    image_num = 0
-    for i, line in enumerate(img_lst):
-        img_path, msk_path = line.strip().split(' ')
-        img_path = os.path.join(args.data_root, img_path)
-        msk_path = os.path.join(args.data_root, msk_path)
-        img_ = cv2.imread(img_path)
-        msk_ = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)
-        batch_img_lst.append(img_)
-        batch_msk_lst.append(msk_)
-        bi += 1
-        if bi == args.batch_size:
-            batch_res = eval_batch_scales(args, eval_net, batch_img_lst, scales=args.scales,
-                                          base_crop_size=args.crop_size, flip=args.flip)
-            for mi in range(args.batch_size):
-                hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
-
-            bi = 0
-            batch_img_lst = []
-            batch_msk_lst = []
-            print('processed {} images'.format(i+1))
-        image_num = i
-
-    if bi > 0:
-        batch_res = eval_batch_scales(args, eval_net, batch_img_lst, scales=args.scales,
-                                      base_crop_size=args.crop_size, flip=args.flip)
-        for mi in range(bi):
-            hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
-        print('processed {} images'.format(image_num + 1))
-
-    print(hist)
-    iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
-    print('per-class IoU', iu)
-    print('mean IoU', np.nanmean(iu))
-
-
-if __name__ == '__main__':
-    net_eval()
+    net_eval(args, eval_net)
diff --git a/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s16_r1.sh b/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s16_r1.sh
index 5c490de58e..f9d71b2379 100644
--- a/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s16_r1.sh
+++ b/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s16_r1.sh
@@ -27,7 +27,10 @@ if [ -d ${train_path} ]; then
 fi
 mkdir -p ${train_path}
 mkdir ${train_path}/ckpt
-
+# To evaluate while training, pass --run_eval=True and set --data_root to the
+# dataset root and --val_data to the validation list file.
+# If evaluation is not needed, both flags can be left as empty strings ("").
+
 for((i=0;i<=$RANK_SIZE-1;i++));
 do
     export RANK_ID=${i}
@@ -37,6 +40,9 @@ do
     cd ${train_path}/device${DEVICE_ID} || exit
     python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
                                        --data_file=/PATH/TO/MINDRECORD_NAME \
+                                       --data_root=/PATH/TO/DATA \
+                                       --val_data=/PATH/TO/DATA_lst.txt \
+                                       --scales=1.0 \
                                        --train_epochs=300 \
                                        --batch_size=32 \
                                        --crop_size=513 \
diff --git a/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s8_r1.sh b/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s8_r1.sh
index 2e50fba39a..eb7fa6665c 100644
--- a/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s8_r1.sh
+++ b/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s8_r1.sh
@@ -27,7 +27,10 @@ if [ -d ${train_path} ]; then
 fi
 mkdir -p ${train_path}
 mkdir ${train_path}/ckpt
-
+# To evaluate while training, pass --run_eval=True and set --data_root to the
+# dataset root and --val_data to the validation list file.
+# If evaluation is not needed, both flags can be left as empty strings ("").
+
 for((i=0;i<=$RANK_SIZE-1;i++));
 do
     export RANK_ID=${i}
@@ -37,6 +40,9 @@ do
     cd ${train_path}/device${DEVICE_ID} || exit
     python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
                                        --data_file=/PATH/TO/MINDRECORD_NAME \
+                                       --data_root=/PATH/TO/DATA \
+                                       --val_data=/PATH/TO/DATA_lst.txt \
+                                       --scales=1.0 \
                                        --train_epochs=800 \
                                        --batch_size=16 \
                                        --crop_size=513 \
diff --git a/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s8_r2.sh b/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s8_r2.sh
index 0a34002ae9..5e7dcf32b6 100644
--- a/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s8_r2.sh
+++ b/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s8_r2.sh
@@ -27,7 +27,10 @@ if [ -d ${train_path} ]; then
 fi
 mkdir -p ${train_path}
 mkdir ${train_path}/ckpt
-
+# To evaluate while training, pass --run_eval=True and set --data_root to the
+# dataset root and --val_data to the validation list file.
+# If evaluation is not needed, both flags can be left as empty strings ("").
+
 for((i=0;i<=$RANK_SIZE-1;i++));
 do
     export RANK_ID=${i}
@@ -37,6 +40,9 @@ do
     cd ${train_path}/device${DEVICE_ID} || exit
     python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
                                        --data_file=/PATH/TO/MINDRECORD_NAME \
+                                       --data_root=/PATH/TO/DATA \
+                                       --val_data=/PATH/TO/DATA_lst.txt \
+                                       --scales=1.0 \
                                        --train_epochs=300 \
                                        --batch_size=16 \
                                        --crop_size=513 \
diff --git a/model_zoo/official/cv/deeplabv3/scripts/run_eval_s16.sh b/model_zoo/official/cv/deeplabv3/scripts/run_eval_s16.sh
index 66305e860a..11dd7827f6 100644
--- a/model_zoo/official/cv/deeplabv3/scripts/run_eval_s16.sh
+++ b/model_zoo/official/cv/deeplabv3/scripts/run_eval_s16.sh
@@ -32,6 +32,5 @@ python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
                                   --num_classes=21 \
                                   --model=deeplab_v3_s16 \
                                   --scales=1.0 \
-                                  --freeze_bn \
                                   --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
diff --git a/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8.sh b/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8.sh
index a189089ceb..2a7bd78c08 100644
--- a/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8.sh
+++ b/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8.sh
@@ -32,6 +32,5 @@ python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
                                   --num_classes=21 \
                                   --model=deeplab_v3_s8 \
                                   --scales=1.0 \
-                                  --freeze_bn \
                                   --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
diff --git a/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8_multiscale.sh b/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8_multiscale.sh
index 824d539e3a..b332ef7950 100644
--- a/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8_multiscale.sh
+++ b/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8_multiscale.sh
@@ -36,6 +36,5 @@ python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
                                   --scales=1.0 \
                                   --scales=1.25 \
                                   --scales=1.75 \
-                                  --freeze_bn \
                                   --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
diff --git a/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8_multiscale_flip.sh b/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8_multiscale_flip.sh
index 88beb11d6f..2c6159b949 100644
--- a/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8_multiscale_flip.sh
+++ b/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8_multiscale_flip.sh
@@ -37,6 +37,5 @@ python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
                                   --scales=1.25 \
                                   --scales=1.75 \
                                   --flip \
-                                  --freeze_bn \
                                   --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
diff --git a/model_zoo/official/cv/deeplabv3/scripts/run_standalone_train.sh b/model_zoo/official/cv/deeplabv3/scripts/run_standalone_train.sh
index a9a741ff2a..41c104c35e 100644
--- a/model_zoo/official/cv/deeplabv3/scripts/run_standalone_train.sh
+++ b/model_zoo/official/cv/deeplabv3/scripts/run_standalone_train.sh
@@ -26,9 +26,15 @@ mkdir -p ${train_path}
 mkdir ${train_path}/device${DEVICE_ID}
 mkdir ${train_path}/ckpt
 cd ${train_path}/device${DEVICE_ID} || exit
-
+# To evaluate while training, pass --run_eval=True and set --data_root to the
+# dataset root and --val_data to the validation list file.
+# If evaluation is not needed, both flags can be left as empty strings ("").
+
 python ${train_code_path}/train.py --data_file=/PATH/TO/MINDRECORD_NAME \
                                    --train_dir=${train_path}/ckpt \
+                                   --data_root=/PATH/TO/DATA \
+                                   --val_data=/PATH/TO/DATA_lst.txt \
+                                   --scales=1.0 \
                                    --train_epochs=200 \
                                    --batch_size=32 \
                                    --crop_size=513 \
diff --git a/model_zoo/official/cv/deeplabv3/src/nets/deeplab_v3/deeplab_v3.py b/model_zoo/official/cv/deeplabv3/src/nets/deeplab_v3/deeplab_v3.py
index 354ec9a133..f38677e327 100644
--- a/model_zoo/official/cv/deeplabv3/src/nets/deeplab_v3/deeplab_v3.py
+++ b/model_zoo/official/cv/deeplabv3/src/nets/deeplab_v3/deeplab_v3.py
@@ -27,47 +27,41 @@ def conv3x3(in_planes, out_planes, stride=1, dilation=1, padding=1):


 class Resnet(nn.Cell):
-    def __init__(self, block, block_num, output_stride, use_batch_statistics=True):
+    def __init__(self, block, block_num, output_stride):
         super(Resnet, self).__init__()
         self.inplanes = 64
         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, pad_mode='pad', padding=3,
                                weight_init='xavier_uniform')
-        self.bn1 = nn.BatchNorm2d(self.inplanes, use_batch_statistics=use_batch_statistics)
+        self.bn1 = nn.BatchNorm2d(self.inplanes)
         self.relu = nn.ReLU()
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
-        self.layer1 = self._make_layer(block, 64, block_num[0], use_batch_statistics=use_batch_statistics)
-        self.layer2 = self._make_layer(block, 128, block_num[1], stride=2, use_batch_statistics=use_batch_statistics)
+        self.layer1 = self._make_layer(block, 64, block_num[0])
+        self.layer2 = self._make_layer(block, 128, block_num[1], stride=2)

         if output_stride == 16:
-            self.layer3 = self._make_layer(block, 256, block_num[2], stride=2,
-                                           use_batch_statistics=use_batch_statistics)
-            self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=2, grids=[1, 2, 4],
-                                           use_batch_statistics=use_batch_statistics)
+            self.layer3 = self._make_layer(block, 256, block_num[2], stride=2)
+            self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=2, grids=[1, 2, 4])
         elif output_stride == 8:
-            self.layer3 = self._make_layer(block, 256, block_num[2], stride=1, base_dilation=2,
-                                           use_batch_statistics=use_batch_statistics)
-            self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=4, grids=[1, 2, 4],
-                                           use_batch_statistics=use_batch_statistics)
+            self.layer3 = self._make_layer(block, 256, block_num[2], stride=1, base_dilation=2)
+            self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=4, grids=[1, 2, 4])

-    def _make_layer(self, block, planes, blocks, stride=1, base_dilation=1, grids=None, use_batch_statistics=True):
+    def _make_layer(self, block, planes, blocks, stride=1, base_dilation=1, grids=None):
         if stride != 1 or self.inplanes != planes * block.expansion:
             downsample = nn.SequentialCell([
                 conv1x1(self.inplanes, planes * block.expansion, stride),
-                nn.BatchNorm2d(planes * block.expansion, use_batch_statistics=use_batch_statistics)
+                nn.BatchNorm2d(planes * block.expansion)
             ])

         if grids is None:
             grids = [1] * blocks

         layers = [
-            block(self.inplanes, planes, stride, downsample, dilation=base_dilation * grids[0],
-                  use_batch_statistics=use_batch_statistics)
+            block(self.inplanes, planes, stride, downsample, dilation=base_dilation * grids[0])
         ]

         self.inplanes = planes * block.expansion
         for i in range(1, blocks):
             layers.append(
-                block(self.inplanes, planes, dilation=base_dilation * grids[i],
-                      use_batch_statistics=use_batch_statistics))
+                block(self.inplanes, planes, dilation=base_dilation * grids[i]))

         return
nn.SequentialCell(layers) @@ -88,16 +82,16 @@ class Resnet(nn.Cell): class Bottleneck(nn.Cell): expansion = 4 - def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, use_batch_statistics=True): + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): super(Bottleneck, self).__init__() self.conv1 = conv1x1(inplanes, planes) - self.bn1 = nn.BatchNorm2d(planes, use_batch_statistics=use_batch_statistics) + self.bn1 = nn.BatchNorm2d(planes) self.conv2 = conv3x3(planes, planes, stride, dilation, dilation) - self.bn2 = nn.BatchNorm2d(planes, use_batch_statistics=use_batch_statistics) + self.bn2 = nn.BatchNorm2d(planes) self.conv3 = conv1x1(planes, planes * self.expansion) - self.bn3 = nn.BatchNorm2d(planes * self.expansion, use_batch_statistics=use_batch_statistics) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) self.relu = nn.ReLU() self.downsample = downsample @@ -127,19 +121,17 @@ class Bottleneck(nn.Cell): class ASPP(nn.Cell): - def __init__(self, atrous_rates, phase='train', in_channels=2048, num_classes=21, - use_batch_statistics=True): + def __init__(self, atrous_rates, in_channels=2048, num_classes=21): super(ASPP, self).__init__() - self.phase = phase out_channels = 256 - self.aspp1 = ASPPConv(in_channels, out_channels, atrous_rates[0], use_batch_statistics=use_batch_statistics) - self.aspp2 = ASPPConv(in_channels, out_channels, atrous_rates[1], use_batch_statistics=use_batch_statistics) - self.aspp3 = ASPPConv(in_channels, out_channels, atrous_rates[2], use_batch_statistics=use_batch_statistics) - self.aspp4 = ASPPConv(in_channels, out_channels, atrous_rates[3], use_batch_statistics=use_batch_statistics) - self.aspp_pooling = ASPPPooling(in_channels, out_channels, use_batch_statistics=use_batch_statistics) + self.aspp1 = ASPPConv(in_channels, out_channels, atrous_rates[0]) + self.aspp2 = ASPPConv(in_channels, out_channels, atrous_rates[1]) + self.aspp3 = ASPPConv(in_channels, out_channels, atrous_rates[2]) + self.aspp4 = ASPPConv(in_channels, out_channels, atrous_rates[3]) + self.aspp_pooling = ASPPPooling(in_channels, out_channels) self.conv1 = nn.Conv2d(out_channels * (len(atrous_rates) + 1), out_channels, kernel_size=1, weight_init='xavier_uniform') - self.bn1 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics) + self.bn1 = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU() self.conv2 = nn.Conv2d(out_channels, num_classes, kernel_size=1, weight_init='xavier_uniform', has_bias=True) self.concat = P.Concat(axis=1) @@ -160,18 +152,18 @@ class ASPP(nn.Cell): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) - if self.phase == 'train': + if self.training: x = self.drop(x) x = self.conv2(x) return x class ASPPPooling(nn.Cell): - def __init__(self, in_channels, out_channels, use_batch_statistics=True): + def __init__(self, in_channels, out_channels): super(ASPPPooling, self).__init__() self.conv = nn.SequentialCell([ nn.Conv2d(in_channels, out_channels, kernel_size=1, weight_init='xavier_uniform'), - nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics), + nn.BatchNorm2d(out_channels), nn.ReLU() ]) self.shape = P.Shape() @@ -185,14 +177,14 @@ class ASPPPooling(nn.Cell): class ASPPConv(nn.Cell): - def __init__(self, in_channels, out_channels, atrous_rate=1, use_batch_statistics=True): + def __init__(self, in_channels, out_channels, atrous_rate=1): super(ASPPConv, self).__init__() if atrous_rate == 1: conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, has_bias=False, 
weight_init='xavier_uniform') else: conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, pad_mode='pad', padding=atrous_rate, dilation=atrous_rate, weight_init='xavier_uniform') - bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics) + bn = nn.BatchNorm2d(out_channels) relu = nn.ReLU() self.aspp_conv = nn.SequentialCell([conv, bn, relu]) @@ -202,13 +194,10 @@ class ASPPConv(nn.Cell): class DeepLabV3(nn.Cell): - def __init__(self, phase='train', num_classes=21, output_stride=16, freeze_bn=False): + def __init__(self, num_classes=21, output_stride=16): super(DeepLabV3, self).__init__() - use_batch_statistics = not freeze_bn - self.resnet = Resnet(Bottleneck, [3, 4, 23, 3], output_stride=output_stride, - use_batch_statistics=use_batch_statistics) - self.aspp = ASPP([1, 6, 12, 18], phase, 2048, num_classes, - use_batch_statistics=use_batch_statistics) + self.resnet = Resnet(Bottleneck, [3, 4, 23, 3], output_stride=output_stride) + self.aspp = ASPP([1, 6, 12, 18], 2048, num_classes) self.shape = P.Shape() def construct(self, x): diff --git a/model_zoo/official/cv/deeplabv3/src/utils/eval_callback.py b/model_zoo/official/cv/deeplabv3/src/utils/eval_callback.py new file mode 100644 index 0000000000..5f607897e6 --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/src/utils/eval_callback.py @@ -0,0 +1,85 @@ +import os +import stat +from mindspore.train.callback import Callback +from mindspore import log as logger +from mindspore import save_checkpoint +from src.utils.eval_utils import BuildEvalNetwork, net_eval + +class EvalCallBack(Callback): + """ + Evaluation callback when training. + + Args: + eval_function (function): evaluation function. + eval_param_dict (dict): evaluation parameters' configure dict. + interval (int): run evaluation interval, default is 1. + eval_start_epoch (int): evaluation start epoch, default is 1. + save_best_ckpt (bool): Whether to save best checkpoint, default is True. + besk_ckpt_name (str): bast checkpoint name, default is `best.ckpt`. + metrics_name (str): evaluation metrics name, default is `acc`. 
+ + Returns: + None + + Examples: + >>> EvalCallBack(eval_function, eval_param_dict) + """ + + def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True, + ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"): + super(EvalCallBack, self).__init__() + self.eval_param_dict = eval_param_dict + self.eval_function = eval_function + self.eval_start_epoch = eval_start_epoch + if interval < 1: + raise ValueError("interval should >= 1.") + self.interval = interval + self.save_best_ckpt = save_best_ckpt + self.best_res = 0 + self.best_epoch = 0 + if not os.path.isdir(ckpt_directory): + os.makedirs(ckpt_directory) + self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name) + self.metrics_name = metrics_name + + def remove_ckpoint_file(self, file_name): + """Remove the specified checkpoint file from this checkpoint manager and also from the directory.""" + try: + os.chmod(file_name, stat.S_IWRITE) + os.remove(file_name) + except OSError: + logger.warning("OSError, failed to remove the older ckpt file %s.", file_name) + except ValueError: + logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name) + + def epoch_end(self, run_context): + """Callback when epoch end.""" + cb_params = run_context.original_args() + cur_epoch = cb_params.cur_epoch_num + if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0: + res = self.eval_function(self.eval_param_dict) + print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True) + if res >= self.best_res: + self.best_res = res + self.best_epoch = cur_epoch + print("update best result: {}".format(res), flush=True) + if self.save_best_ckpt: + if os.path.exists(self.bast_ckpt_path): + self.remove_ckpoint_file(self.bast_ckpt_path) + save_checkpoint(cb_params.train_network, self.bast_ckpt_path) + print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True) + + def end(self, run_context): + print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name, + self.best_res, + self.best_epoch), flush=True) + + +def apply_eval(eval_param_dict): + network = eval_param_dict["net"] + network = BuildEvalNetwork(network) + network.set_train(False) + args = eval_param_dict["args"] + args.data_lst = eval_param_dict["dataset"] + eval_result = net_eval(args, network) + return eval_result diff --git a/model_zoo/official/cv/deeplabv3/src/utils/eval_utils.py b/model_zoo/official/cv/deeplabv3/src/utils/eval_utils.py new file mode 100644 index 0000000000..6ba8fcb8f7 --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/src/utils/eval_utils.py @@ -0,0 +1,152 @@ +import os +import cv2 +import numpy as np +from mindspore import Tensor +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.ops as ops + + +def cal_hist(a, b, n): + k = (a >= 0) & (a < n) + return np.bincount(n * a[k].astype(np.int32) + b[k], minlength=n ** 2).reshape(n, n) + + +def resize_long(img, long_size=513): + h, w, _ = img.shape + if h > w: + new_h = long_size + new_w = int(1.0 * long_size * w / h) + else: + new_w = long_size + new_h = int(1.0 * long_size * h / w) + imo = cv2.resize(img, (new_w, new_h)) + return imo + + +class BuildEvalNetwork(nn.Cell): + def __init__(self, network, input_format="NCHW"): + super(BuildEvalNetwork, self).__init__() + self.network = network + self.softmax = nn.Softmax(axis=1) + self.transpose = ops.Transpose() + self.format = input_format + + def construct(self, 
input_data): + if self.format == "NHWC": + input_data = self.transpose(input_data, (0, 3, 1, 2)) + output = self.network(input_data) + output = self.softmax(output) + return output + + +def pre_process(args, img_, crop_size=513): + # resize + img_ = resize_long(img_, crop_size) + resize_h, resize_w, _ = img_.shape + + # mean, std + image_mean = np.array(args.image_mean) + image_std = np.array(args.image_std) + img_ = (img_ - image_mean) / image_std + + # pad to crop_size + pad_h = crop_size - img_.shape[0] + pad_w = crop_size - img_.shape[1] + if pad_h > 0 or pad_w > 0: + img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0) + # hwc to chw + img_ = img_.transpose((2, 0, 1)) + return img_, resize_h, resize_w + + +def eval_batch(args, eval_net, img_lst, crop_size=513, flip=True): + result_lst = [] + batch_size = len(img_lst) + batch_img = np.zeros((args.batch_size, 3, crop_size, crop_size), dtype=np.float32) + resize_hw = [] + for l in range(batch_size): + img_ = img_lst[l] + img_, resize_h, resize_w = pre_process(args, img_, crop_size) + batch_img[l] = img_ + resize_hw.append([resize_h, resize_w]) + + batch_img = np.ascontiguousarray(batch_img) + net_out = eval_net(Tensor(batch_img, mstype.float32)) + net_out = net_out.asnumpy() + + if flip: + batch_img = batch_img[:, :, :, ::-1] + net_out_flip = eval_net(Tensor(batch_img, mstype.float32)) + net_out += net_out_flip.asnumpy()[:, :, :, ::-1] + + for bs in range(batch_size): + probs_ = net_out[bs][:, :resize_hw[bs][0], :resize_hw[bs][1]].transpose((1, 2, 0)) + ori_h, ori_w = img_lst[bs].shape[0], img_lst[bs].shape[1] + probs_ = cv2.resize(probs_, (ori_w, ori_h)) + result_lst.append(probs_) + + return result_lst + + +def eval_batch_scales(args, eval_net, img_lst, scales, + base_crop_size=513, flip=True): + sizes_ = [int((base_crop_size - 1) * sc) + 1 for sc in scales] + probs_lst = eval_batch(args, eval_net, img_lst, crop_size=sizes_[0], flip=flip) + print(sizes_) + for crop_size_ in sizes_[1:]: + probs_lst_tmp = eval_batch(args, eval_net, img_lst, crop_size=crop_size_, flip=flip) + for pl, _ in enumerate(probs_lst): + probs_lst[pl] += probs_lst_tmp[pl] + + result_msk = [] + for i in probs_lst: + result_msk.append(i.argmax(axis=2)) + return result_msk + + +def net_eval(args, eval_net): + # data list + with open(args.data_lst) as f: + img_lst = f.readlines() + + # evaluate + hist = np.zeros((args.num_classes, args.num_classes)) + batch_img_lst = [] + batch_msk_lst = [] + bi = 0 + image_num = 0 + for i, line in enumerate(img_lst): + img_path, msk_path = line.strip().split(' ') + img_path = os.path.join(args.data_root, img_path) + msk_path = os.path.join(args.data_root, msk_path) + img_ = cv2.imread(img_path) + msk_ = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE) + batch_img_lst.append(img_) + batch_msk_lst.append(msk_) + bi += 1 + if bi == args.batch_size: + batch_res = eval_batch_scales(args, eval_net, batch_img_lst, scales=args.scales, + base_crop_size=args.crop_size, flip=args.flip) + for mi in range(args.batch_size): + hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes) + + bi = 0 + batch_img_lst = [] + batch_msk_lst = [] + print('processed {} images'.format(i + 1)) + image_num = i + + if bi > 0: + batch_res = eval_batch_scales(args, eval_net, batch_img_lst, scales=args.scales, + base_crop_size=args.crop_size, flip=args.flip) + for mi in range(bi): + hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes) + print('processed {} 
images'.format(image_num + 1)) + + print(hist) + iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) + mean_iou = np.nanmean(iu) + print('per-class IoU', iu) + print('mean IoU', mean_iou) + return mean_iou diff --git a/model_zoo/official/cv/deeplabv3/train.py b/model_zoo/official/cv/deeplabv3/train.py index 4139246d01..a116e847a6 100644 --- a/model_zoo/official/cv/deeplabv3/train.py +++ b/model_zoo/official/cv/deeplabv3/train.py @@ -31,7 +31,8 @@ from src.data import dataset as data_generator from src.loss import loss from src.nets import net_factory from src.utils import learning_rates - +from src.utils.eval_utils import BuildEvalNetwork +from src.utils.eval_callback import EvalCallBack, apply_eval set_seed(1) @@ -72,7 +73,6 @@ def parse_args(): # model parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model') - parser.add_argument('--freeze_bn', action='store_true', help='freeze bn') parser.add_argument('--ckpt_pre_trained', type=str, default='', help='pretrained model') parser.add_argument("--filter_weight", type=ast.literal_eval, default=False, help="Filter the last weight parameters, default is False.") @@ -86,6 +86,22 @@ def parse_args(): parser.add_argument('--save_steps', type=int, default=3000, help='steps interval for saving') parser.add_argument('--keep_checkpoint_max', type=int, default=int, help='max checkpoint for saving') + # validate + parser.add_argument("--run_eval", type=ast.literal_eval, default=False, + help="Run evaluation when training, default is False.") + parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True, + help="Save best checkpoint when run_eval is True, default is True.") + parser.add_argument("--eval_start_epoch", type=int, default=200, + help="Evaluation start epoch when run_eval is True, default is 200.") + parser.add_argument("--eval_interval", type=int, default=1, + help="Evaluation interval when run_eval is True, default is 1.") + parser.add_argument('--ann_file', type=str, default='', help='path to annotation') + parser.add_argument('--data_root', type=str, default='', help='root path of val data') + parser.add_argument('--val_data', type=str, default='', help='list of val data') + parser.add_argument('--scales', type=float, action='append', help='scales of evaluation') + parser.add_argument('--flip', action='store_true', help='perform left-right flip') + parser.add_argument("--input_format", type=str, choices=["NCHW", "NHWC"], default="NCHW", + help="NCHW or NHWC") args, _ = parser.parse_known_args() return args @@ -126,9 +142,9 @@ def train(): # network if args.model == 'deeplab_v3_s16': - network = net_factory.nets_map[args.model]('train', args.num_classes, 16, args.freeze_bn) + network = net_factory.nets_map[args.model](args.num_classes, 16) elif args.model == 'deeplab_v3_s8': - network = net_factory.nets_map[args.model]('train', args.num_classes, 8, args.freeze_bn) + network = net_factory.nets_map[args.model](args.num_classes, 8) else: raise NotImplementedError('model [{:s}] not recognized'.format(args.model)) @@ -169,6 +185,7 @@ def train(): # loss scale manager_loss_scale = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False) amp_level = "O0" if args.device_target == "CPU" else "O3" + train_net.set_train(True) model = Model(train_net, optimizer=opt, amp_level=amp_level, loss_scale_manager=manager_loss_scale) # callback for saving ckpts @@ -182,6 +199,16 @@ def train(): ckpoint_cb = ModelCheckpoint(prefix=args.model, directory=args.train_dir, config=config_ck) 
 cbs.append(ckpoint_cb)
+    if args.run_eval and args.rank == 0:
+        network_eval = BuildEvalNetwork(network, args.input_format)
+        eval_dataset = args.val_data
+        save_ckpt_path = args.train_dir
+        eval_param_dict = {"net": network_eval, "dataset": eval_dataset, "args": args}
+        eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args.eval_interval,
+                               eval_start_epoch=args.eval_start_epoch, save_best_ckpt=args.save_best_ckpt,
+                               ckpt_directory=save_ckpt_path, besk_ckpt_name="best_map.ckpt",
+                               metrics_name="mIoU")
+        cbs.append(eval_cb)
 model.train(args.train_epochs, dataset, callbacks=cbs, dataset_sink_mode=(args.device_target != "CPU"))
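For reference, a minimal single-device invocation sketch with train-time evaluation enabled, built only from the flags this patch adds to train.py (all /PATH/TO/... values are placeholders; remaining options keep their parse_args() defaults):

python train.py --data_file=/PATH/TO/MINDRECORD_NAME \
                --train_dir=/PATH/TO/CKPT_DIR \
                --train_epochs=300 \
                --batch_size=32 \
                --crop_size=513 \
                --model=deeplab_v3_s16 \
                --run_eval=True \
                --data_root=/PATH/TO/DATA \
                --val_data=/PATH/TO/DATA_lst.txt \
                --scales=1.0 \
                --eval_start_epoch=200 \
                --eval_interval=1 \
                --save_best_ckpt=True \
                --input_format=NCHW

With these settings, EvalCallBack reports mIoU each evaluation round and writes the best checkpoint to --train_dir as best_map.ckpt.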