@@ -0,0 +1,47 @@
# Dataset
Dataset used: [COCO2017](https://cocodataset.org/)
- Dataset size: 19G
    - Train: 18G, 118,000 images
    - Val: 1G, 5,000 images
    - Annotations: 241M (instances, captions, person_keypoints, etc.)
- Data format: images and JSON files
- Note: the data will be processed in dataset.py
# Environment Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the COCO2017 dataset, which is the dataset used in this example.

Install Cython and pycocotools; you can also install mmcv to process the data.
```
pip install Cython
pip install pycocotools
pip install mmcv==0.2.14
```
Then change `coco_root` and any other settings you need in `config.py`. The expected directory structure is as follows:
```
.
└─cocodataset
  ├─annotations
    ├─instances_train2017.json
    └─instances_val2017.json
  ├─val2017
  └─train2017
```
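For reference, the dataset-related entries in `config.py` might look like the sketch below. Only the field names `coco_root` and `mindrecord_dir` are taken from their use in `coco_attack_pgd.py`; the paths are placeholders to adapt to your environment:
```
# config.py (excerpt, illustrative values only)
coco_root = "/data/cocodataset"        # root of the directory tree shown above
mindrecord_dir = "/data/mindrecord"    # where FasterRcnn_eval.mindrecord will be written
```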
# Quick start
Download the pre-trained model checkpoint file [here](https://www.mindspore.cn/resources/hub/details?2505/MindSpore/ascend/0.7/fasterrcnn_v1.0_coco2017), then run:
```
python coco_attack_pgd.py --ann_file [VAL_JSON_FILE] --pre_trained [PRETRAINED_CHECKPOINT_FILE]
```
> Adversarial samples will be generated and saved as a pickle file (`adv_samples.pkl`).
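A minimal sketch for loading the generated samples afterwards (the file name `adv_samples.pkl` is the one written by `coco_attack_pgd.py`; each entry is one adversarial image array):
```
import pickle

with open('adv_samples.pkl', 'rb') as f:
    adv_samples = pickle.load(f)
print('number of adversarial images:', len(adv_samples))
```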
@@ -0,0 +1,135 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PGD attack on Faster R-CNN."""
import os
import argparse
import pickle

from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from mindspore.nn import Cell
from mindspore.ops.composite import GradOperation

from mindarmour.adv_robustness.attacks import ProjectedGradientDescent

from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50
from src.config import config
from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset

# pylint: disable=locally-disabled, unused-argument, redefined-outer-name

set_seed(1)

parser = argparse.ArgumentParser(description='FasterRCNN attack')
parser.add_argument('--ann_file', type=str, required=True, help='Annotation file path.')
parser.add_argument('--pre_trained', type=str, required=True, help='Pre-trained ckpt file path for the target model.')
parser.add_argument('--device_id', type=int, default=0, help='Device id, default is 0.')
parser.add_argument('--num', type=int, default=5, help='Number of adversarial examples.')
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=args.device_id)


class LossNet(Cell):
    """Loss function used by the attack.

    The inputs x1..x6 correspond to the six training-mode outputs of
    Faster_Rcnn_Resnet50 (rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss,
    rcnn_cls_loss, rcnn_reg_loss), so the attacked loss is the sum of the
    RPN and RCNN regression losses.
    """
    def construct(self, x1, x2, x3, x4, x5, x6):
        return x4 + x6


class WithLossCell(Cell):
    """Wrap the network with the loss function."""
    def __init__(self, backbone, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_num):
        loss1, loss2, loss3, loss4, loss5, loss6 = self._backbone(img_data, img_metas, gt_bboxes, gt_labels, gt_num)
        return self._loss_fn(loss1, loss2, loss3, loss4, loss5, loss6)

    @property
    def backbone_network(self):
        return self._backbone


class GradWrapWithLoss(Cell):
    """
    Construct a network that computes the gradient of the loss function
    with respect to the inputs.
    """
    def __init__(self, network):
        super(GradWrapWithLoss, self).__init__()
        self._grad_all = GradOperation(get_all=True, sens_param=False)
        self._network = network

    def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_num):
        gout = self._grad_all(self._network)(img_data, img_metas, gt_bboxes, gt_labels, gt_num)
        return gout[0]


if __name__ == '__main__':
    prefix = 'FasterRcnn_eval.mindrecord'
    mindrecord_dir = config.mindrecord_dir
    mindrecord_file = os.path.join(mindrecord_dir, prefix)
    pre_trained = args.pre_trained
    ann_file = args.ann_file

    print("CHECKING MINDRECORD FILES ...")
    if not os.path.exists(mindrecord_file):
        if not os.path.isdir(mindrecord_dir):
            os.makedirs(mindrecord_dir)
        if os.path.isdir(config.coco_root):
            print("Create Mindrecord. It may take some time.")
            data_to_mindrecord_byte_image("coco", False, prefix, file_num=1)
            print("Create Mindrecord Done, at {}".format(mindrecord_dir))
        else:
            print("coco_root does not exist.")

    print('Start generating adversarial samples.')

    # build network and dataset
    ds = create_fasterrcnn_dataset(mindrecord_file, batch_size=config.test_batch_size, \
                                   repeat_num=1, is_training=True)
    net = Faster_Rcnn_Resnet50(config)
    param_dict = load_checkpoint(pre_trained)
    load_param_into_net(net, param_dict)
    net = net.set_train()

    # build attacker
    with_loss_cell = WithLossCell(net, LossNet())
    grad_with_loss_net = GradWrapWithLoss(with_loss_cell)
    attack = ProjectedGradientDescent(grad_with_loss_net, bounds=None, eps=0.1)

    # generate adversarial samples
    num = args.num
    num_batches = num // config.test_batch_size
    channel = 3
    adv_samples = [0] * (num_batches * config.test_batch_size)
    adv_id = 0
    for data in ds.create_dict_iterator(num_epochs=num_batches):
        img_data = data['image']
        img_metas = data['image_shape']
        gt_bboxes = data['box']
        gt_labels = data['label']
        gt_num = data['valid_num']
        adv_img = attack.generate(img_data.asnumpy(), \
                                  (img_metas.asnumpy(), gt_bboxes.asnumpy(), gt_labels.asnumpy(), gt_num.asnumpy()))
        for item in adv_img:
            adv_samples[adv_id] = item
            adv_id += 1
        if adv_id >= num_batches * config.test_batch_size:
            # stop once the requested number of adversarial samples is reached
            break
    pickle.dump(adv_samples, open('adv_samples.pkl', 'wb'))
    print('Generate adversarial samples complete.')
@@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn Init."""
from .resnet50 import ResNetFea, ResidualBlockUsing
from .bbox_assign_sample import BboxAssignSample
from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
from .fpn_neck import FeatPyramidNeck
from .proposal_generator import Proposal
from .rcnn import Rcnn
from .rpn import RPN
from .roi_align import SingleRoIExtractor
from .anchor_generator import AnchorGenerator

__all__ = [
    "ResNetFea", "BboxAssignSample", "BboxAssignSampleForRcnn",
    "FeatPyramidNeck", "Proposal", "Rcnn",
    "RPN", "SingleRoIExtractor", "AnchorGenerator", "ResidualBlockUsing"
]
@@ -0,0 +1,84 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn anchor generator."""
import numpy as np


class AnchorGenerator():
    """Anchor generator for FasterRcnn."""

    def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
        """Anchor generator init method."""
        self.base_size = base_size
        self.scales = np.array(scales)
        self.ratios = np.array(ratios)
        self.scale_major = scale_major
        self.ctr = ctr
        self.base_anchors = self.gen_base_anchors()

    def gen_base_anchors(self):
        """Generate a single anchor."""
        w = self.base_size
        h = self.base_size
        if self.ctr is None:
            x_ctr = 0.5 * (w - 1)
            y_ctr = 0.5 * (h - 1)
        else:
            x_ctr, y_ctr = self.ctr
        h_ratios = np.sqrt(self.ratios)
        w_ratios = 1 / h_ratios
        if self.scale_major:
            ws = (w * w_ratios[:, None] * self.scales[None, :]).reshape(-1)
            hs = (h * h_ratios[:, None] * self.scales[None, :]).reshape(-1)
        else:
            ws = (w * self.scales[:, None] * w_ratios[None, :]).reshape(-1)
            hs = (h * self.scales[:, None] * h_ratios[None, :]).reshape(-1)
        base_anchors = np.stack(
            [
                x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
                x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
            ],
            axis=-1).round()
        return base_anchors

    def _meshgrid(self, x, y, row_major=True):
        """Generate grid."""
        xx = np.repeat(x.reshape(1, len(x)), len(y), axis=0).reshape(-1)
        yy = np.repeat(y, len(x))
        if row_major:
            return xx, yy
        return yy, xx

    def grid_anchors(self, featmap_size, stride=16):
        """Generate anchor list."""
        base_anchors = self.base_anchors
        feat_h, feat_w = featmap_size
        shift_x = np.arange(0, feat_w) * stride
        shift_y = np.arange(0, feat_h) * stride
        shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
        shifts = np.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
        shifts = shifts.astype(base_anchors.dtype)
        # first feat_w elements correspond to the first row of shifts
        # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
        # shifted anchors (K, A, 4), reshape to (K*A, 4)
        all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
        all_anchors = all_anchors.reshape(-1, 4)
        return all_anchors
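# Illustrative usage only (not part of the original file): with an assumed
# base_size, one scale and three ratios, grid_anchors() tiles the three base
# anchors over a feature map at the given stride, e.g.
#     gen = AnchorGenerator(base_size=8, scales=[8], ratios=[0.5, 1.0, 2.0])
#     anchors = gen.grid_anchors((48, 80), stride=16)  # array of shape (48 * 80 * 3, 4)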
@@ -0,0 +1,166 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn positive and negative sample screening for RPN."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype

# pylint: disable=locally-disabled, invalid-name, missing-docstring


class BboxAssignSample(nn.Cell):
    """
    Bbox assigner and sampler definition.

    Args:
        config (dict): Config.
        batch_size (int): Batchsize.
        num_bboxes (int): The anchor nums.
        add_gt_as_proposals (bool): add gt bboxes as proposals flag.

    Returns:
        Tensor, output tensor.
        bbox_targets: bbox location, (batch_size, num_bboxes, 4)
        bbox_weights: bbox weights, (batch_size, num_bboxes, 1)
        labels: label for every bboxes, (batch_size, num_bboxes, 1)
        label_weights: label weight for every bboxes, (batch_size, num_bboxes, 1)

    Examples:
        BboxAssignSample(config, 2, 1024, True)
    """

    def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
        super(BboxAssignSample, self).__init__()
        cfg = config
        self.batch_size = batch_size
        self.neg_iou_thr = Tensor(cfg.neg_iou_thr, mstype.float16)
        self.pos_iou_thr = Tensor(cfg.pos_iou_thr, mstype.float16)
        self.min_pos_iou = Tensor(cfg.min_pos_iou, mstype.float16)
        self.zero_thr = Tensor(0.0, mstype.float16)
        self.num_bboxes = num_bboxes
        self.num_gts = cfg.num_gts
        self.num_expected_pos = cfg.num_expected_pos
        self.num_expected_neg = cfg.num_expected_neg
        self.add_gt_as_proposals = add_gt_as_proposals
        if self.add_gt_as_proposals:
            self.label_inds = Tensor(np.arange(1, self.num_gts + 1))
        self.concat = P.Concat(axis=0)
        self.max_gt = P.ArgMaxWithValue(axis=0)
        self.max_anchor = P.ArgMaxWithValue(axis=1)
        self.sum_inds = P.ReduceSum()
        self.iou = P.IOU()
        self.greaterequal = P.GreaterEqual()
        self.greater = P.Greater()
        self.select = P.Select()
        self.gatherND = P.GatherNd()
        self.squeeze = P.Squeeze()
        self.cast = P.Cast()
        self.logicaland = P.LogicalAnd()
        self.less = P.Less()
        self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
        self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
        self.reshape = P.Reshape()
        self.equal = P.Equal()
        self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0))
        self.scatterNdUpdate = P.ScatterNdUpdate()
        self.scatterNd = P.ScatterNd()
        self.logicalnot = P.LogicalNot()
        self.tile = P.Tile()
        self.zeros_like = P.ZerosLike()
        self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
        self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
        self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
        self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
        self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))
        self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool))
        self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
        self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
        self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))

    def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
        gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
                                  (self.num_gts, 1)), (1, 4)), mstype.bool_), gt_bboxes_i, self.check_gt_one)
        bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \
                             (self.num_bboxes, 1)), (1, 4)), mstype.bool_), bboxes, self.check_anchor_two)
        overlaps = self.iou(bboxes, gt_bboxes_i)
        max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
        _, max_overlaps_w_ac = self.max_anchor(overlaps)
        neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt, self.zero_thr), \
                                              self.less(max_overlaps_w_gt, self.neg_iou_thr))
        assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds)
        pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.pos_iou_thr)
        assigned_gt_inds3 = self.select(pos_sample_iou_mask, \
                                        max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)
        assigned_gt_inds4 = assigned_gt_inds3
        for j in range(self.num_gts):
            max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1]
            overlaps_w_gt_j = self.squeeze(overlaps[j:j+1:1, ::])
            pos_mask_j = self.logicaland(self.greaterequal(max_overlaps_w_ac_j, self.min_pos_iou), \
                                         self.equal(overlaps_w_gt_j, max_overlaps_w_ac_j))
            assigned_gt_inds4 = self.select(pos_mask_j, self.assigned_gt_ones + j, assigned_gt_inds4)
        assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds4, self.assigned_gt_ignores)
        pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))
        pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16)
        pos_check_valid = self.sum_inds(pos_check_valid, -1)
        valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
        pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))
        pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
        pos_assigned_gt_index = pos_assigned_gt_index * self.cast(valid_pos_index, mstype.int32)
        pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, (self.num_expected_pos, 1))
        neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))
        num_pos = self.cast(self.logicalnot(valid_pos_index), mstype.float16)
        num_pos = self.sum_inds(num_pos, -1)
        unvalid_pos_index = self.less(self.range_pos_size, num_pos)
        valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)
        pos_bboxes_ = self.gatherND(bboxes, pos_index)
        pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
        pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)
        pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)
        valid_pos_index = self.cast(valid_pos_index, mstype.int32)
        valid_neg_index = self.cast(valid_neg_index, mstype.int32)
        bbox_targets_total = self.scatterNd(pos_index, pos_bbox_targets_, (self.num_bboxes, 4))
        bbox_weights_total = self.scatterNd(pos_index, valid_pos_index, (self.num_bboxes,))
        labels_total = self.scatterNd(pos_index, pos_gt_labels, (self.num_bboxes,))
        total_index = self.concat((pos_index, neg_index))
        total_valid_index = self.concat((valid_pos_index, valid_neg_index))
        label_weights_total = self.scatterNd(total_index, total_valid_index, (self.num_bboxes,))
        return bbox_targets_total, self.cast(bbox_weights_total, mstype.bool_), \
               labels_total, self.cast(label_weights_total, mstype.bool_)
@@ -0,0 +1,197 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn positive and negative sample screening for Rcnn."""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor

# pylint: disable=locally-disabled, invalid-name, missing-docstring


class BboxAssignSampleForRcnn(nn.Cell):
    """
    Bbox assigner and sampler definition.

    Args:
        config (dict): Config.
        batch_size (int): Batchsize.
        num_bboxes (int): The anchor nums.
        add_gt_as_proposals (bool): add gt bboxes as proposals flag.

    Returns:
        Tensor, output tensor.
        bbox_targets: bbox location, (batch_size, num_bboxes, 4)
        bbox_weights: bbox weights, (batch_size, num_bboxes, 1)
        labels: label for every bboxes, (batch_size, num_bboxes, 1)
        label_weights: label weight for every bboxes, (batch_size, num_bboxes, 1)

    Examples:
        BboxAssignSampleForRcnn(config, 2, 1024, True)
    """

    def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
        super(BboxAssignSampleForRcnn, self).__init__()
        cfg = config
        self.batch_size = batch_size
        self.neg_iou_thr = cfg.neg_iou_thr_stage2
        self.pos_iou_thr = cfg.pos_iou_thr_stage2
        self.min_pos_iou = cfg.min_pos_iou_stage2
        self.num_gts = cfg.num_gts
        self.num_bboxes = num_bboxes
        self.num_expected_pos = cfg.num_expected_pos_stage2
        self.num_expected_neg = cfg.num_expected_neg_stage2
        self.num_expected_total = cfg.num_expected_total_stage2
        self.add_gt_as_proposals = add_gt_as_proposals
        self.label_inds = Tensor(np.arange(1, self.num_gts + 1).astype(np.int32))
        self.add_gt_as_proposals_valid = Tensor(np.array(self.add_gt_as_proposals * np.ones(self.num_gts),
                                                         dtype=np.int32))
        self.concat = P.Concat(axis=0)
        self.max_gt = P.ArgMaxWithValue(axis=0)
        self.max_anchor = P.ArgMaxWithValue(axis=1)
        self.sum_inds = P.ReduceSum()
        self.iou = P.IOU()
        self.greaterequal = P.GreaterEqual()
        self.greater = P.Greater()
        self.select = P.Select()
        self.gatherND = P.GatherNd()
        self.squeeze = P.Squeeze()
        self.cast = P.Cast()
        self.logicaland = P.LogicalAnd()
        self.less = P.Less()
        self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
        self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
        self.reshape = P.Reshape()
        self.equal = P.Equal()
        self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(0.1, 0.1, 0.2, 0.2))
        self.concat_axis1 = P.Concat(axis=1)
        self.logicalnot = P.LogicalNot()
        self.tile = P.Tile()
        # Check
        self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
        self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))
        # Init tensor
        self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
        self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
        self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
        self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
        self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))
        self.gt_ignores = Tensor(np.array(-1 * np.ones(self.num_gts), dtype=np.int32))
        self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
        self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool))
        self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=np.float16))
        self.labels_neg_mask = Tensor(np.array(np.zeros(self.num_expected_neg), dtype=np.uint8))
        self.reshape_shape_pos = (self.num_expected_pos, 1)
        self.reshape_shape_neg = (self.num_expected_neg, 1)
        self.scalar_zero = Tensor(0.0, dtype=mstype.float16)
        self.scalar_neg_iou_thr = Tensor(self.neg_iou_thr, dtype=mstype.float16)
        self.scalar_pos_iou_thr = Tensor(self.pos_iou_thr, dtype=mstype.float16)
        self.scalar_min_pos_iou = Tensor(self.min_pos_iou, dtype=mstype.float16)

    def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
        gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
                                  (self.num_gts, 1)), (1, 4)), mstype.bool_), \
                                  gt_bboxes_i, self.check_gt_one)
        bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \
                             (self.num_bboxes, 1)), (1, 4)), mstype.bool_), \
                             bboxes, self.check_anchor_two)
        overlaps = self.iou(bboxes, gt_bboxes_i)
        max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
        _, max_overlaps_w_ac = self.max_anchor(overlaps)
        neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt,
                                                                self.scalar_zero),
                                              self.less(max_overlaps_w_gt,
                                                        self.scalar_neg_iou_thr))
        assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds)
        pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.scalar_pos_iou_thr)
        assigned_gt_inds3 = self.select(pos_sample_iou_mask, \
                                        max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)
        for j in range(self.num_gts):
            max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1]
            overlaps_w_ac_j = overlaps[j:j+1:1, ::]
            temp1 = self.greaterequal(max_overlaps_w_ac_j, self.scalar_min_pos_iou)
            temp2 = self.squeeze(self.equal(overlaps_w_ac_j, max_overlaps_w_ac_j))
            pos_mask_j = self.logicaland(temp1, temp2)
            assigned_gt_inds3 = self.select(pos_mask_j, (j+1)*self.assigned_gt_ones, assigned_gt_inds3)
        assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds3, self.assigned_gt_ignores)
        bboxes = self.concat((gt_bboxes_i, bboxes))
        label_inds_valid = self.select(gt_valids, self.label_inds, self.gt_ignores)
        label_inds_valid = label_inds_valid * self.add_gt_as_proposals_valid
        assigned_gt_inds5 = self.concat((label_inds_valid, assigned_gt_inds5))
        # Get pos index
        pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))
        pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16)
        pos_check_valid = self.sum_inds(pos_check_valid, -1)
        valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
        pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))
        num_pos = self.sum_inds(self.cast(self.logicalnot(valid_pos_index), mstype.float16), -1)
        valid_pos_index = self.cast(valid_pos_index, mstype.int32)
        pos_index = self.reshape(pos_index, self.reshape_shape_pos)
        valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos)
        pos_index = pos_index * valid_pos_index
        pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
        pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, self.reshape_shape_pos)
        pos_assigned_gt_index = pos_assigned_gt_index * valid_pos_index
        pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)
        # Get neg index
        neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))
        unvalid_pos_index = self.less(self.range_pos_size, num_pos)
        valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)
        neg_index = self.reshape(neg_index, self.reshape_shape_neg)
        valid_neg_index = self.cast(valid_neg_index, mstype.int32)
        valid_neg_index = self.reshape(valid_neg_index, self.reshape_shape_neg)
        neg_index = neg_index * valid_neg_index
        pos_bboxes_ = self.gatherND(bboxes, pos_index)
        neg_bboxes_ = self.gatherND(bboxes, neg_index)
        pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, self.reshape_shape_pos)
        pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
        pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)
        total_bboxes = self.concat((pos_bboxes_, neg_bboxes_))
        total_deltas = self.concat((pos_bbox_targets_, self.bboxs_neg_mask))
        total_labels = self.concat((pos_gt_labels, self.labels_neg_mask))
        valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos)
        valid_neg_index = self.reshape(valid_neg_index, self.reshape_shape_neg)
        total_mask = self.concat((valid_pos_index, valid_neg_index))
        return total_bboxes, total_deltas, total_labels, total_mask
@@ -0,0 +1,428 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn based on ResNet50."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F

from .resnet50 import ResNetFea, ResidualBlockUsing
from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
from .fpn_neck import FeatPyramidNeck
from .proposal_generator import Proposal
from .rcnn import Rcnn
from .rpn import RPN
from .roi_align import SingleRoIExtractor
from .anchor_generator import AnchorGenerator

# pylint: disable=locally-disabled, invalid-name, missing-docstring


class Faster_Rcnn_Resnet50(nn.Cell):
    """
    FasterRcnn Network.

    Note:
        backbone = resnet50

    Returns:
        Tuple, tuple of output tensor.
        rpn_loss: Scalar, Total loss of RPN subnet.
        rcnn_loss: Scalar, Total loss of RCNN subnet.
        rpn_cls_loss: Scalar, Classification loss of RPN subnet.
        rpn_reg_loss: Scalar, Regression loss of RPN subnet.
        rcnn_cls_loss: Scalar, Classification loss of RCNN subnet.
        rcnn_reg_loss: Scalar, Regression loss of RCNN subnet.

    Examples:
        net = Faster_Rcnn_Resnet50()
    """

    def __init__(self, config):
        super(Faster_Rcnn_Resnet50, self).__init__()
        self.train_batch_size = config.batch_size
        self.num_classes = config.num_classes
        self.anchor_scales = config.anchor_scales
        self.anchor_ratios = config.anchor_ratios
        self.anchor_strides = config.anchor_strides
        self.target_means = tuple(config.rcnn_target_means)
        self.target_stds = tuple(config.rcnn_target_stds)
        # Anchor generator
        anchor_base_sizes = None
        self.anchor_base_sizes = list(
            self.anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
        self.anchor_generators = []
        for anchor_base in self.anchor_base_sizes:
            self.anchor_generators.append(
                AnchorGenerator(anchor_base, self.anchor_scales, self.anchor_ratios))
        self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
        featmap_sizes = config.feature_shapes
        assert len(featmap_sizes) == len(self.anchor_generators)
        self.anchor_list = self.get_anchors(featmap_sizes)
        # Backbone resnet50
        self.backbone = ResNetFea(ResidualBlockUsing,
                                  config.resnet_block,
                                  config.resnet_in_channels,
                                  config.resnet_out_channels,
                                  False)
        # Fpn
        self.fpn_neck = FeatPyramidNeck(config.fpn_in_channels,
                                        config.fpn_out_channels,
                                        config.fpn_num_outs)
        # Rpn and rpn loss
        self.gt_labels_stage1 = Tensor(np.ones((self.train_batch_size, config.num_gts)).astype(np.uint8))
        self.rpn_with_loss = RPN(config,
                                 self.train_batch_size,
                                 config.rpn_in_channels,
                                 config.rpn_feat_channels,
                                 config.num_anchors,
                                 config.rpn_cls_out_channels)
        # Proposal
        self.proposal_generator = Proposal(config,
                                           self.train_batch_size,
                                           config.activate_num_classes,
                                           config.use_sigmoid_cls)
        self.proposal_generator.set_train_local(config, True)
        self.proposal_generator_test = Proposal(config,
                                                config.test_batch_size,
                                                config.activate_num_classes,
                                                config.use_sigmoid_cls)
        self.proposal_generator_test.set_train_local(config, False)
        # Assign and sampler stage two
        self.bbox_assigner_sampler_for_rcnn = BboxAssignSampleForRcnn(config, self.train_batch_size,
                                                                      config.num_bboxes_stage2, True)
        self.decode = P.BoundingBoxDecode(max_shape=(768, 1280), means=self.target_means, \
                                          stds=self.target_stds)
        # Roi
        self.roi_align = SingleRoIExtractor(config,
                                            config.roi_layer,
                                            config.roi_align_out_channels,
                                            config.roi_align_featmap_strides,
                                            self.train_batch_size,
                                            config.roi_align_finest_scale)
        self.roi_align.set_train_local(config, True)
        self.roi_align_test = SingleRoIExtractor(config,
                                                 config.roi_layer,
                                                 config.roi_align_out_channels,
                                                 config.roi_align_featmap_strides,
                                                 1,
                                                 config.roi_align_finest_scale)
        self.roi_align_test.set_train_local(config, False)
        # Rcnn
        self.rcnn = Rcnn(config, config.rcnn_in_channels * config.roi_layer['out_size'] * config.roi_layer['out_size'],
                         self.train_batch_size, self.num_classes)
        # Op declare
        self.squeeze = P.Squeeze()
        self.cast = P.Cast()
        self.concat = P.Concat(axis=0)
        self.concat_1 = P.Concat(axis=1)
        self.concat_2 = P.Concat(axis=2)
        self.reshape = P.Reshape()
        self.select = P.Select()
        self.greater = P.Greater()
        self.transpose = P.Transpose()
        # Test mode
        self.test_batch_size = config.test_batch_size
        self.split = P.Split(axis=0, output_num=self.test_batch_size)
        self.split_shape = P.Split(axis=0, output_num=4)
        self.split_scores = P.Split(axis=1, output_num=self.num_classes)
        self.split_cls = P.Split(axis=0, output_num=self.num_classes-1)
        self.tile = P.Tile()
        self.gather = P.GatherNd()
        self.rpn_max_num = config.rpn_max_num
        self.zeros_for_nms = Tensor(np.zeros((self.rpn_max_num, 3)).astype(np.float16))
        self.ones_mask = np.ones((self.rpn_max_num, 1)).astype(np.bool)
        self.zeros_mask = np.zeros((self.rpn_max_num, 1)).astype(np.bool)
        self.bbox_mask = Tensor(np.concatenate((self.ones_mask, self.zeros_mask,
                                                self.ones_mask, self.zeros_mask), axis=1))
        self.nms_pad_mask = Tensor(np.concatenate((self.ones_mask, self.ones_mask,
                                                   self.ones_mask, self.ones_mask, self.zeros_mask), axis=1))
        self.test_score_thresh = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * config.test_score_thr)
        self.test_score_zeros = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * 0)
        self.test_box_zeros = Tensor(np.ones((self.rpn_max_num, 4)).astype(np.float16) * -1)
        self.test_iou_thr = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * config.test_iou_thr)
        self.test_max_per_img = config.test_max_per_img
        self.nms_test = P.NMSWithMask(config.test_iou_thr)
        self.softmax = P.Softmax(axis=1)
        self.logicand = P.LogicalAnd()
        self.oneslike = P.OnesLike()
        self.test_topk = P.TopK(sorted=True)
        self.test_num_proposal = self.test_batch_size * self.rpn_max_num
        # Improve speed
        self.concat_start = min(self.num_classes - 2, 55)
        self.concat_end = (self.num_classes - 1)
        # Init tensor
        roi_align_index = [np.array(np.ones((config.num_expected_pos_stage2 + config.num_expected_neg_stage2, 1)) * i,
                                    dtype=np.float16) for i in range(self.train_batch_size)]
        roi_align_index_test = [np.array(np.ones((config.rpn_max_num, 1)) * i, dtype=np.float16) \
                                for i in range(self.test_batch_size)]
        self.roi_align_index_tensor = Tensor(np.concatenate(roi_align_index))
        self.roi_align_index_test_tensor = Tensor(np.concatenate(roi_align_index_test))

    def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids):
        x = self.backbone(img_data)
        x = self.fpn_neck(x)
        rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss, _ = self.rpn_with_loss(x,
                                                                                           img_metas,
                                                                                           self.anchor_list,
                                                                                           gt_bboxes,
                                                                                           self.gt_labels_stage1,
                                                                                           gt_valids)
        if self.training:
            proposal, proposal_mask = self.proposal_generator(cls_score, bbox_pred, self.anchor_list)
        else:
            proposal, proposal_mask = self.proposal_generator_test(cls_score, bbox_pred, self.anchor_list)
        gt_labels = self.cast(gt_labels, mstype.int32)
        gt_valids = self.cast(gt_valids, mstype.int32)
        bboxes_tuple = ()
        deltas_tuple = ()
        labels_tuple = ()
        mask_tuple = ()
        if self.training:
            for i in range(self.train_batch_size):
                gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])
                gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
                gt_labels_i = self.cast(gt_labels_i, mstype.uint8)
                gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
                gt_valids_i = self.cast(gt_valids_i, mstype.bool_)
                bboxes, deltas, labels, mask = self.bbox_assigner_sampler_for_rcnn(gt_bboxes_i,
                                                                                   gt_labels_i,
                                                                                   proposal_mask[i],
                                                                                   proposal[i][::, 0:4:1],
                                                                                   gt_valids_i)
                bboxes_tuple += (bboxes,)
                deltas_tuple += (deltas,)
                labels_tuple += (labels,)
                mask_tuple += (mask,)
            bbox_targets = self.concat(deltas_tuple)
            rcnn_labels = self.concat(labels_tuple)
            bbox_targets = F.stop_gradient(bbox_targets)
            rcnn_labels = F.stop_gradient(rcnn_labels)
            rcnn_labels = self.cast(rcnn_labels, mstype.int32)
        else:
            mask_tuple += proposal_mask
            bbox_targets = proposal_mask
            rcnn_labels = proposal_mask
            for p_i in proposal:
                bboxes_tuple += (p_i[::, 0:4:1],)
        if self.training:
            if self.train_batch_size > 1:
                bboxes_all = self.concat(bboxes_tuple)
            else:
                bboxes_all = bboxes_tuple[0]
            rois = self.concat_1((self.roi_align_index_tensor, bboxes_all))
        else:
            if self.test_batch_size > 1:
                bboxes_all = self.concat(bboxes_tuple)
            else:
                bboxes_all = bboxes_tuple[0]
            rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all))
        rois = self.cast(rois, mstype.float32)
        rois = F.stop_gradient(rois)
        if self.training:
            roi_feats = self.roi_align(rois,
                                       self.cast(x[0], mstype.float32),
                                       self.cast(x[1], mstype.float32),
                                       self.cast(x[2], mstype.float32),
                                       self.cast(x[3], mstype.float32))
        else:
            roi_feats = self.roi_align_test(rois,
                                            self.cast(x[0], mstype.float32),
                                            self.cast(x[1], mstype.float32),
                                            self.cast(x[2], mstype.float32),
                                            self.cast(x[3], mstype.float32))
        roi_feats = self.cast(roi_feats, mstype.float16)
        rcnn_masks = self.concat(mask_tuple)
        rcnn_masks = F.stop_gradient(rcnn_masks)
        rcnn_mask_squeeze = self.squeeze(self.cast(rcnn_masks, mstype.bool_))
        rcnn_loss, rcnn_cls_loss, rcnn_reg_loss, _ = self.rcnn(roi_feats,
                                                               bbox_targets,
                                                               rcnn_labels,
                                                               rcnn_mask_squeeze)
        output = ()
        if self.training:
            output += (rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss, rcnn_cls_loss, rcnn_reg_loss)
        else:
            output = self.get_det_bboxes(rcnn_cls_loss, rcnn_reg_loss, rcnn_masks, bboxes_all, img_metas)
        return output

    def get_det_bboxes(self, cls_logits, reg_logits, mask_logits, rois, img_metas):
        """Get the actual detection box."""
        scores = self.softmax(cls_logits)
        boxes_all = ()
        for i in range(self.num_classes):
            k = i * 4
            reg_logits_i = self.squeeze(reg_logits[::, k:k+4:1])
            out_boxes_i = self.decode(rois, reg_logits_i)
            boxes_all += (out_boxes_i,)
        img_metas_all = self.split(img_metas)
        scores_all = self.split(scores)
        mask_all = self.split(self.cast(mask_logits, mstype.int32))
        boxes_all_with_batchsize = ()
        for i in range(self.test_batch_size):
            scale = self.split_shape(self.squeeze(img_metas_all[i]))
            scale_h = scale[2]
            scale_w = scale[3]
            boxes_tuple = ()
            for j in range(self.num_classes):
                boxes_tmp = self.split(boxes_all[j])
                out_boxes_h = boxes_tmp[i] / scale_h
                out_boxes_w = boxes_tmp[i] / scale_w
                boxes_tuple += (self.select(self.bbox_mask, out_boxes_w, out_boxes_h),)
            boxes_all_with_batchsize += (boxes_tuple,)
        output = self.multiclass_nms(boxes_all_with_batchsize, scores_all, mask_all)
        return output

    def multiclass_nms(self, boxes_all, scores_all, mask_all):
        """Multiclass NMS postprocessing."""
        all_bboxes = ()
        all_labels = ()
        all_masks = ()
        for i in range(self.test_batch_size):
            bboxes = boxes_all[i]
            scores = scores_all[i]
            masks = self.cast(mask_all[i], mstype.bool_)
            res_boxes_tuple = ()
            res_labels_tuple = ()
            res_masks_tuple = ()
            for j in range(self.num_classes - 1):
                k = j + 1
                _cls_scores = scores[::, k:k + 1:1]
                _bboxes = self.squeeze(bboxes[k])
                _mask_o = self.reshape(masks, (self.rpn_max_num, 1))
                cls_mask = self.greater(_cls_scores, self.test_score_thresh)
                _mask = self.logicand(_mask_o, cls_mask)
                _reg_mask = self.cast(self.tile(self.cast(_mask, mstype.int32), (1, 4)), mstype.bool_)
                _bboxes = self.select(_reg_mask, _bboxes, self.test_box_zeros)
                _cls_scores = self.select(_mask, _cls_scores, self.test_score_zeros)
                __cls_scores = self.squeeze(_cls_scores)
                scores_sorted, topk_inds = self.test_topk(__cls_scores, self.rpn_max_num)
                topk_inds = self.reshape(topk_inds, (self.rpn_max_num, 1))
                scores_sorted = self.reshape(scores_sorted, (self.rpn_max_num, 1))
                _bboxes_sorted = self.gather(_bboxes, topk_inds)
                _mask_sorted = self.gather(_mask, topk_inds)
                scores_sorted = self.tile(scores_sorted, (1, 4))
                cls_dets = self.concat_1((_bboxes_sorted, scores_sorted))
                cls_dets = P.Slice()(cls_dets, (0, 0), (self.rpn_max_num, 5))
                cls_dets, _index, _mask_nms = self.nms_test(cls_dets)
                _index = self.reshape(_index, (self.rpn_max_num, 1))
                _mask_nms = self.reshape(_mask_nms, (self.rpn_max_num, 1))
                _mask_n = self.gather(_mask_sorted, _index)
                _mask_n = self.logicand(_mask_n, _mask_nms)
                cls_labels = self.oneslike(_index) * j
                res_boxes_tuple += (cls_dets,)
                res_labels_tuple += (cls_labels,)
                res_masks_tuple += (_mask_n,)
            res_boxes_start = self.concat(res_boxes_tuple[:self.concat_start])
            res_labels_start = self.concat(res_labels_tuple[:self.concat_start])
            res_masks_start = self.concat(res_masks_tuple[:self.concat_start])
            res_boxes_end = self.concat(res_boxes_tuple[self.concat_start:self.concat_end])
            res_labels_end = self.concat(res_labels_tuple[self.concat_start:self.concat_end])
            res_masks_end = self.concat(res_masks_tuple[self.concat_start:self.concat_end])
            res_boxes = self.concat((res_boxes_start, res_boxes_end))
            res_labels = self.concat((res_labels_start, res_labels_end))
            res_masks = self.concat((res_masks_start, res_masks_end))
            reshape_size = (self.num_classes - 1) * self.rpn_max_num
            res_boxes = self.reshape(res_boxes, (1, reshape_size, 5))
            res_labels = self.reshape(res_labels, (1, reshape_size, 1))
            res_masks = self.reshape(res_masks, (1, reshape_size, 1))
            all_bboxes += (res_boxes,)
            all_labels += (res_labels,)
            all_masks += (res_masks,)
        all_bboxes = self.concat(all_bboxes)
        all_labels = self.concat(all_labels)
        all_masks = self.concat(all_masks)
        return all_bboxes, all_labels, all_masks

    def get_anchors(self, featmap_sizes):
        """Get anchors according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.

        Returns:
            tuple: anchors of each image, valid flags of each image
        """
        num_levels = len(featmap_sizes)
        # since feature map sizes of all images are the same, we only compute
        # anchors for one time
        multi_level_anchors = ()
        for i in range(num_levels):
            anchors = self.anchor_generators[i].grid_anchors(
                featmap_sizes[i], self.anchor_strides[i])
            multi_level_anchors += (Tensor(anchors.astype(np.float16)),)
        return multi_level_anchors
@@ -0,0 +1,114 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn feature pyramid network."""
import numpy as np
import mindspore.nn as nn
from mindspore import context
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer

# pylint: disable=locally-disabled, missing-docstring

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


def bias_init_zeros(shape):
    """Bias init method."""
    return Tensor(np.array(np.zeros(shape).astype(np.float32)).astype(np.float16))


def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
    """Conv2D wrapper."""
    shape = (out_channels, in_channels, kernel_size, kernel_size)
    weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor()
    shape_bias = (out_channels,)
    biass = bias_init_zeros(shape_bias)
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=biass)


class FeatPyramidNeck(nn.Cell):
    """
    Feature pyramid network cell, usually used as the network neck.

    Applies convolution on multiple input feature maps and outputs feature
    maps with the same channel size. If the required number of outputs is
    larger than the number of inputs, extra max-pooling layers are added for
    further downsampling.

    Args:
        in_channels (tuple) - Channel sizes of the input feature maps.
        out_channels (int) - Channel size of the output.
        num_outs (int) - Number of output features.

    Returns:
        Tuple, with tensors of same channel size.

    Examples:
        neck = FeatPyramidNeck([100,200,300], 50, 4)
        input_data = (normal(0,0.1,(1,c,1280//(4*2**i), 768//(4*2**i)),
                      dtype=np.float32) \
                      for i, c in enumerate(config.fpn_in_channels))
        x = neck(input_data)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs):
        super(FeatPyramidNeck, self).__init__()
        self.num_outs = num_outs
        self.in_channels = in_channels
        self.fpn_layer = len(self.in_channels)
        assert not self.num_outs < len(in_channels)
        self.lateral_convs_list_ = []
        self.fpn_convs_ = []
        for _, channel in enumerate(in_channels):
            l_conv = _conv(channel, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='valid')
            fpn_conv = _conv(out_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='same')
            self.lateral_convs_list_.append(l_conv)
            self.fpn_convs_.append(fpn_conv)
        self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_)
        self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_)
        self.interpolate1 = P.ResizeNearestNeighbor((48, 80))
        self.interpolate2 = P.ResizeNearestNeighbor((96, 160))
        self.interpolate3 = P.ResizeNearestNeighbor((192, 320))
        self.maxpool = P.MaxPool(ksize=1, strides=2, padding="same")

    def construct(self, inputs):
        x = ()
        for i in range(self.fpn_layer):
            x += (self.lateral_convs_list[i](inputs[i]),)
        y = (x[3],)
        y = y + (x[2] + self.interpolate1(y[self.fpn_layer - 4]),)
        y = y + (x[1] + self.interpolate2(y[self.fpn_layer - 3]),)
        y = y + (x[0] + self.interpolate3(y[self.fpn_layer - 2]),)
        z = ()
        for i in range(self.fpn_layer - 1, -1, -1):
            z = z + (y[i],)
        outs = ()
        for i in range(self.fpn_layer):
            outs = outs + (self.fpn_convs_list[i](z[i]),)
        for i in range(self.num_outs - self.fpn_layer):
            outs = outs + (self.maxpool(outs[3]),)
        return outs
@@ -0,0 +1,201 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""FasterRcnn proposal generator.""" | |||||
import numpy as np | |||||
import mindspore.nn as nn | |||||
import mindspore.common.dtype as mstype | |||||
from mindspore.ops import operations as P | |||||
from mindspore import Tensor | |||||
from mindspore import context | |||||
# pylint: disable=locally-disabled, invalid-name, missing-docstring | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
class Proposal(nn.Cell): | |||||
""" | |||||
Proposal subnet. | |||||
Args: | |||||
config (dict): Config. | |||||
batch_size (int): Batchsize. | |||||
num_classes (int) - Class number. | |||||
use_sigmoid_cls (bool) - Select sigmoid or softmax function. | |||||
target_means (tuple) - Means for encode function. Default: (.0, .0, .0, .0). | |||||
target_stds (tuple) - Stds for encode function. Default: (1.0, 1.0, 1.0, 1.0). | |||||
Returns: | |||||
Tuple, tuple of output tensor,(proposal, mask). | |||||
Examples: | |||||
Proposal(config = config, batch_size = 1, num_classes = 81, use_sigmoid_cls = True, \ | |||||
target_means=(.0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0)) | |||||
""" | |||||
def __init__(self, | |||||
config, | |||||
batch_size, | |||||
num_classes, | |||||
use_sigmoid_cls, | |||||
target_means=(.0, .0, .0, .0), | |||||
target_stds=(1.0, 1.0, 1.0, 1.0) | |||||
): | |||||
super(Proposal, self).__init__() | |||||
cfg = config | |||||
self.batch_size = batch_size | |||||
self.num_classes = num_classes | |||||
self.target_means = target_means | |||||
self.target_stds = target_stds | |||||
self.use_sigmoid_cls = use_sigmoid_cls | |||||
if self.use_sigmoid_cls: | |||||
self.cls_out_channels = num_classes - 1 | |||||
self.activation = P.Sigmoid() | |||||
self.reshape_shape = (-1, 1) | |||||
else: | |||||
self.cls_out_channels = num_classes | |||||
self.activation = P.Softmax(axis=1) | |||||
self.reshape_shape = (-1, 2) | |||||
if self.cls_out_channels <= 0: | |||||
raise ValueError('num_classes={} is too small'.format(num_classes)) | |||||
self.num_pre = cfg.rpn_proposal_nms_pre | |||||
self.min_box_size = cfg.rpn_proposal_min_bbox_size | |||||
self.nms_thr = cfg.rpn_proposal_nms_thr | |||||
self.nms_post = cfg.rpn_proposal_nms_post | |||||
self.nms_across_levels = cfg.rpn_proposal_nms_across_levels | |||||
self.max_num = cfg.rpn_proposal_max_num | |||||
self.num_levels = cfg.fpn_num_outs | |||||
# Op Define | |||||
self.squeeze = P.Squeeze() | |||||
self.reshape = P.Reshape() | |||||
self.cast = P.Cast() | |||||
self.feature_shapes = cfg.feature_shapes | |||||
self.transpose_shape = (1, 2, 0) | |||||
self.decode = P.BoundingBoxDecode(max_shape=(cfg.img_height, cfg.img_width), \ | |||||
means=self.target_means, \ | |||||
stds=self.target_stds) | |||||
self.nms = P.NMSWithMask(self.nms_thr) | |||||
self.concat_axis0 = P.Concat(axis=0) | |||||
self.concat_axis1 = P.Concat(axis=1) | |||||
self.split = P.Split(axis=1, output_num=5) | |||||
self.min = P.Minimum() | |||||
self.gatherND = P.GatherNd() | |||||
self.slice = P.Slice() | |||||
self.select = P.Select() | |||||
self.greater = P.Greater() | |||||
self.transpose = P.Transpose() | |||||
self.tile = P.Tile() | |||||
self.set_train_local(config, training=True) | |||||
self.multi_10 = Tensor(10.0, mstype.float16) | |||||
def set_train_local(self, config, training=True): | |||||
"""Set training flag.""" | |||||
self.training_local = training | |||||
cfg = config | |||||
self.topK_stage1 = () | |||||
self.topK_shape = () | |||||
total_max_topk_input = 0 | |||||
if not self.training_local: | |||||
self.num_pre = cfg.rpn_nms_pre | |||||
self.min_box_size = cfg.rpn_min_bbox_min_size | |||||
self.nms_thr = cfg.rpn_nms_thr | |||||
self.nms_post = cfg.rpn_nms_post | |||||
self.nms_across_levels = cfg.rpn_nms_across_levels | |||||
self.max_num = cfg.rpn_max_num | |||||
for shp in self.feature_shapes: | |||||
k_num = min(self.num_pre, (shp[0] * shp[1] * 3)) | |||||
total_max_topk_input += k_num | |||||
self.topK_stage1 += (k_num,) | |||||
self.topK_shape += ((k_num, 1),) | |||||
self.topKv2 = P.TopK(sorted=True) | |||||
self.topK_shape_stage2 = (self.max_num, 1) | |||||
self.min_float_num = -65536.0 | |||||
self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float16)) | |||||
def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list): | |||||
proposals_tuple = () | |||||
masks_tuple = () | |||||
for img_id in range(self.batch_size): | |||||
cls_score_list = () | |||||
bbox_pred_list = () | |||||
for i in range(self.num_levels): | |||||
rpn_cls_score_i = self.squeeze(rpn_cls_score_total[i][img_id:img_id+1:1, ::, ::, ::]) | |||||
rpn_bbox_pred_i = self.squeeze(rpn_bbox_pred_total[i][img_id:img_id+1:1, ::, ::, ::]) | |||||
cls_score_list = cls_score_list + (rpn_cls_score_i,) | |||||
bbox_pred_list = bbox_pred_list + (rpn_bbox_pred_i,) | |||||
proposals, masks = self.get_bboxes_single(cls_score_list, bbox_pred_list, anchor_list) | |||||
proposals_tuple += (proposals,) | |||||
masks_tuple += (masks,) | |||||
return proposals_tuple, masks_tuple | |||||
def get_bboxes_single(self, cls_scores, bbox_preds, mlvl_anchors): | |||||
"""Get proposal boundingbox.""" | |||||
mlvl_proposals = () | |||||
mlvl_mask = () | |||||
for idx in range(self.num_levels): | |||||
rpn_cls_score = self.transpose(cls_scores[idx], self.transpose_shape) | |||||
rpn_bbox_pred = self.transpose(bbox_preds[idx], self.transpose_shape) | |||||
anchors = mlvl_anchors[idx] | |||||
rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape) | |||||
rpn_cls_score = self.activation(rpn_cls_score) | |||||
rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score[::, 0::]), mstype.float16) | |||||
rpn_bbox_pred_process = self.cast(self.reshape(rpn_bbox_pred, (-1, 4)), mstype.float16) | |||||
scores_sorted, topk_inds = self.topKv2(rpn_cls_score_process, self.topK_stage1[idx]) | |||||
topk_inds = self.reshape(topk_inds, self.topK_shape[idx]) | |||||
bboxes_sorted = self.gatherND(rpn_bbox_pred_process, topk_inds) | |||||
anchors_sorted = self.cast(self.gatherND(anchors, topk_inds), mstype.float16) | |||||
proposals_decode = self.decode(anchors_sorted, bboxes_sorted) | |||||
proposals_decode = self.concat_axis1((proposals_decode, self.reshape(scores_sorted, self.topK_shape[idx]))) | |||||
proposals, _, mask_valid = self.nms(proposals_decode) | |||||
mlvl_proposals = mlvl_proposals + (proposals,) | |||||
mlvl_mask = mlvl_mask + (mask_valid,) | |||||
proposals = self.concat_axis0(mlvl_proposals) | |||||
masks = self.concat_axis0(mlvl_mask) | |||||
_, _, _, _, scores = self.split(proposals) | |||||
scores = self.squeeze(scores) | |||||
topk_mask = self.cast(self.topK_mask, mstype.float16) | |||||
scores_using = self.select(masks, scores, topk_mask) | |||||
_, topk_inds = self.topKv2(scores_using, self.max_num) | |||||
topk_inds = self.reshape(topk_inds, self.topK_shape_stage2) | |||||
proposals = self.gatherND(proposals, topk_inds) | |||||
masks = self.gatherND(masks, topk_inds) | |||||
return proposals, masks |
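The per-level selection in `get_bboxes_single` amounts to: score every anchor, keep the top-k per feature level, decode and NMS them, then keep the `max_num` best overall. Below is a minimal NumPy sketch of just the top-k step, with illustrative sizes only; it is not the MindSpore implementation above.
```
import numpy as np

def topk_proposals(scores, boxes, k=4):
    """Keep the k highest-scoring boxes, returning (x1, y1, x2, y2, score) rows."""
    order = np.argsort(scores)[::-1][:k]
    return np.concatenate([boxes[order], scores[order, None]], axis=1)

rng = np.random.default_rng(0)
scores = rng.random(10).astype(np.float32)               # objectness per anchor
boxes = rng.random((10, 4)).astype(np.float32) * 100.0   # decoded x1, y1, x2, y2
print(topk_proposals(scores, boxes).shape)                # (4, 5)
```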
@@ -0,0 +1,173 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""FasterRcnn Rcnn network.""" | |||||
import numpy as np | |||||
import mindspore.common.dtype as mstype | |||||
import mindspore.nn as nn | |||||
from mindspore.ops import operations as P | |||||
from mindspore.common.tensor import Tensor | |||||
from mindspore.common.initializer import initializer | |||||
from mindspore.common.parameter import Parameter | |||||
# pylint: disable=locally-disabled, missing-docstring | |||||
class DenseNoTranpose(nn.Cell): | |||||
"""Dense method""" | |||||
def __init__(self, input_channels, output_channels, weight_init): | |||||
super(DenseNoTranpose, self).__init__() | |||||
self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float16), | |||||
name="weight") | |||||
self.bias = Parameter(initializer("zeros", [output_channels], mstype.float16).to_tensor(), name="bias") | |||||
self.matmul = P.MatMul(transpose_b=False) | |||||
self.bias_add = P.BiasAdd() | |||||
def construct(self, x): | |||||
output = self.bias_add(self.matmul(x, self.weight), self.bias) | |||||
return output | |||||
class Rcnn(nn.Cell): | |||||
""" | |||||
Rcnn subnet. | |||||
Args: | |||||
config (dict) - Config. | |||||
representation_size (int) - Channels of shared dense. | |||||
batch_size (int) - Batchsize. | |||||
num_classes (int) - Class number. | |||||
target_means (tuple) - Means for encode function. Default: (0., 0., 0., 0.).
target_stds (tuple) - Stds for encode function. Default: (0.1, 0.1, 0.2, 0.2).
Returns: | |||||
Tuple, tuple of output tensor. | |||||
Examples: | |||||
Rcnn(config=config, representation_size=1024, batch_size=2, num_classes=81, \
target_means=(0., 0., 0., 0.), target_stds=(0.1, 0.1, 0.2, 0.2))
""" | |||||
def __init__(self, | |||||
config, | |||||
representation_size, | |||||
batch_size, | |||||
num_classes, | |||||
target_means=(0., 0., 0., 0.), | |||||
target_stds=(0.1, 0.1, 0.2, 0.2) | |||||
): | |||||
super(Rcnn, self).__init__() | |||||
cfg = config | |||||
self.rcnn_loss_cls_weight = Tensor(np.array(cfg.rcnn_loss_cls_weight).astype(np.float16)) | |||||
self.rcnn_loss_reg_weight = Tensor(np.array(cfg.rcnn_loss_reg_weight).astype(np.float16)) | |||||
self.rcnn_fc_out_channels = cfg.rcnn_fc_out_channels | |||||
self.target_means = target_means | |||||
self.target_stds = target_stds | |||||
self.num_classes = num_classes | |||||
self.in_channels = cfg.rcnn_in_channels | |||||
self.train_batch_size = batch_size | |||||
self.test_batch_size = cfg.test_batch_size | |||||
shape_0 = (self.rcnn_fc_out_channels, representation_size) | |||||
weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float16).to_tensor() | |||||
shape_1 = (self.rcnn_fc_out_channels, self.rcnn_fc_out_channels) | |||||
weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float16).to_tensor() | |||||
self.shared_fc_0 = DenseNoTranpose(representation_size, self.rcnn_fc_out_channels, weights_0) | |||||
self.shared_fc_1 = DenseNoTranpose(self.rcnn_fc_out_channels, self.rcnn_fc_out_channels, weights_1) | |||||
cls_weight = initializer('Normal', shape=[num_classes, self.rcnn_fc_out_channels][::-1], | |||||
dtype=mstype.float16).to_tensor() | |||||
reg_weight = initializer('Normal', shape=[num_classes * 4, self.rcnn_fc_out_channels][::-1], | |||||
dtype=mstype.float16).to_tensor() | |||||
self.cls_scores = DenseNoTranpose(self.rcnn_fc_out_channels, num_classes, cls_weight) | |||||
self.reg_scores = DenseNoTranpose(self.rcnn_fc_out_channels, num_classes * 4, reg_weight) | |||||
self.flatten = P.Flatten() | |||||
self.relu = P.ReLU() | |||||
self.logicaland = P.LogicalAnd() | |||||
self.loss_cls = P.SoftmaxCrossEntropyWithLogits() | |||||
self.loss_bbox = P.SmoothL1Loss(beta=1.0) | |||||
self.reshape = P.Reshape() | |||||
self.onehot = P.OneHot() | |||||
self.greater = P.Greater() | |||||
self.cast = P.Cast() | |||||
self.sum_loss = P.ReduceSum() | |||||
self.tile = P.Tile() | |||||
self.expandims = P.ExpandDims() | |||||
self.gather = P.GatherNd() | |||||
self.argmax = P.ArgMaxWithValue(axis=1) | |||||
self.on_value = Tensor(1.0, mstype.float32) | |||||
self.off_value = Tensor(0.0, mstype.float32) | |||||
self.value = Tensor(1.0, mstype.float16) | |||||
self.num_bboxes = (cfg.num_expected_pos_stage2 + cfg.num_expected_neg_stage2) * batch_size | |||||
rmv_first = np.ones((self.num_bboxes, self.num_classes)) | |||||
rmv_first[:, 0] = np.zeros((self.num_bboxes,)) | |||||
self.rmv_first_tensor = Tensor(rmv_first.astype(np.float16)) | |||||
self.num_bboxes_test = cfg.rpn_max_num * cfg.test_batch_size | |||||
range_max = np.arange(self.num_bboxes_test).astype(np.int32) | |||||
self.range_max = Tensor(range_max) | |||||
def construct(self, featuremap, bbox_targets, labels, mask): | |||||
x = self.flatten(featuremap) | |||||
x = self.relu(self.shared_fc_0(x)) | |||||
x = self.relu(self.shared_fc_1(x)) | |||||
x_cls = self.cls_scores(x) | |||||
x_reg = self.reg_scores(x) | |||||
if self.training: | |||||
bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels | |||||
labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), mstype.float16) | |||||
bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1)) | |||||
loss, loss_cls, loss_reg, loss_print = self.loss(x_cls, x_reg, bbox_targets, bbox_weights, labels, mask) | |||||
out = (loss, loss_cls, loss_reg, loss_print) | |||||
else: | |||||
out = (x_cls, (x_cls / self.value), x_reg, x_cls) | |||||
return out | |||||
def loss(self, cls_score, bbox_pred, bbox_targets, bbox_weights, labels, weights): | |||||
"""Loss method.""" | |||||
loss_print = () | |||||
loss_cls, _ = self.loss_cls(cls_score, labels) | |||||
weights = self.cast(weights, mstype.float16) | |||||
loss_cls = loss_cls * weights | |||||
loss_cls = self.sum_loss(loss_cls, (0,)) / self.sum_loss(weights, (0,)) | |||||
bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value), | |||||
mstype.float16) | |||||
bbox_weights = bbox_weights * self.rmv_first_tensor | |||||
pos_bbox_pred = self.reshape(bbox_pred, (self.num_bboxes, -1, 4)) | |||||
loss_reg = self.loss_bbox(pos_bbox_pred, bbox_targets) | |||||
loss_reg = self.sum_loss(loss_reg, (2,)) | |||||
loss_reg = loss_reg * bbox_weights | |||||
loss_reg = loss_reg / self.sum_loss(weights, (0,)) | |||||
loss_reg = self.sum_loss(loss_reg, (0, 1)) | |||||
loss = self.rcnn_loss_cls_weight * loss_cls + self.rcnn_loss_reg_weight * loss_reg | |||||
loss_print += (loss_cls, loss_reg) | |||||
return loss, loss_cls, loss_reg, loss_print |
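For orientation, the weighted combination computed in `Rcnn.loss` is a softmax cross-entropy term plus a SmoothL1 box term, each normalised by the number of valid samples. The NumPy sketch below mirrors that structure with made-up shapes and weights; the per-class box masking (`rmv_first_tensor`) is deliberately omitted.
```
import numpy as np

def smooth_l1(pred, target, beta=1.0):
    diff = np.abs(pred - target)
    return np.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)

def rcnn_style_loss(cls_logits, labels, box_pred, box_target, valid,
                    cls_weight=1.0, reg_weight=1.0):
    # softmax cross-entropy, masked and averaged over the valid samples
    logits = cls_logits - cls_logits.max(axis=1, keepdims=True)
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    loss_cls = (-log_prob[np.arange(len(labels)), labels] * valid).sum() / valid.sum()
    # SmoothL1 over the 4 box coordinates for the same samples
    loss_reg = (smooth_l1(box_pred, box_target).sum(axis=1) * valid).sum() / valid.sum()
    return cls_weight * loss_cls + reg_weight * loss_reg

rng = np.random.default_rng(1)
print(rcnn_style_loss(rng.normal(size=(6, 81)), rng.integers(0, 81, size=6),
                      rng.normal(size=(6, 4)), rng.normal(size=(6, 4)),
                      np.array([1., 1., 1., 0., 1., 0.])))
```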
@@ -0,0 +1,250 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""Resnet50 backbone.""" | |||||
import numpy as np | |||||
import mindspore.nn as nn | |||||
from mindspore.ops import operations as P | |||||
from mindspore.common.tensor import Tensor | |||||
from mindspore.ops import functional as F | |||||
from mindspore import context | |||||
# pylint: disable=locally-disabled, invalid-name, missing-docstring | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
def weight_init_ones(shape): | |||||
"""Weight init.""" | |||||
return Tensor(np.array(np.ones(shape).astype(np.float32) * 0.01).astype(np.float16)) | |||||
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'): | |||||
"""Conv2D wrapper.""" | |||||
shape = (out_channels, in_channels, kernel_size, kernel_size) | |||||
weights = weight_init_ones(shape) | |||||
return nn.Conv2d(in_channels, out_channels, | |||||
kernel_size=kernel_size, stride=stride, padding=padding, | |||||
pad_mode=pad_mode, weight_init=weights, has_bias=False) | |||||
def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=True): | |||||
"""Batchnorm2D wrapper.""" | |||||
gamma_init = Tensor(np.array(np.ones(out_chls)).astype(np.float16)) | |||||
beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float16)) | |||||
moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float16)) | |||||
moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(np.float16)) | |||||
return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init, | |||||
beta_init=beta_init, moving_mean_init=moving_mean_init, | |||||
moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics) | |||||
class ResNetFea(nn.Cell): | |||||
""" | |||||
ResNet architecture. | |||||
Args: | |||||
block (Cell): Block for network. | |||||
layer_nums (list): Numbers of block in different layers. | |||||
in_channels (list): Input channel in each layer. | |||||
out_channels (list): Output channel in each layer. | |||||
weights_update (bool): Weight update flag. | |||||
Returns: | |||||
Tensor, output tensor. | |||||
Examples: | |||||
>>> ResNetFea(ResidualBlockUsing,
>>> [3, 4, 6, 3], | |||||
>>> [64, 256, 512, 1024], | |||||
>>> [256, 512, 1024, 2048], | |||||
>>> False) | |||||
""" | |||||
def __init__(self, | |||||
block, | |||||
layer_nums, | |||||
in_channels, | |||||
out_channels, | |||||
weights_update=False): | |||||
super(ResNetFea, self).__init__() | |||||
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4: | |||||
raise ValueError("the length of " | |||||
"layer_num, inchannel, outchannel list must be 4!") | |||||
bn_training = False | |||||
self.conv1 = _conv(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad') | |||||
self.bn1 = _BatchNorm2dInit(64, affine=bn_training, use_batch_statistics=bn_training) | |||||
self.relu = P.ReLU() | |||||
self.maxpool = P.MaxPool(ksize=3, strides=2, padding="SAME") | |||||
self.weights_update = weights_update | |||||
if not self.weights_update: | |||||
self.conv1.weight.requires_grad = False | |||||
self.layer1 = self._make_layer(block, | |||||
layer_nums[0], | |||||
in_channel=in_channels[0], | |||||
out_channel=out_channels[0], | |||||
stride=1, | |||||
training=bn_training, | |||||
weights_update=self.weights_update) | |||||
self.layer2 = self._make_layer(block, | |||||
layer_nums[1], | |||||
in_channel=in_channels[1], | |||||
out_channel=out_channels[1], | |||||
stride=2, | |||||
training=bn_training, | |||||
weights_update=True) | |||||
self.layer3 = self._make_layer(block, | |||||
layer_nums[2], | |||||
in_channel=in_channels[2], | |||||
out_channel=out_channels[2], | |||||
stride=2, | |||||
training=bn_training, | |||||
weights_update=True) | |||||
self.layer4 = self._make_layer(block, | |||||
layer_nums[3], | |||||
in_channel=in_channels[3], | |||||
out_channel=out_channels[3], | |||||
stride=2, | |||||
training=bn_training, | |||||
weights_update=True) | |||||
def _make_layer(self, block, layer_num, in_channel, out_channel, stride, training=False, weights_update=False): | |||||
"""Make block layer.""" | |||||
layers = [] | |||||
down_sample = False | |||||
if stride != 1 or in_channel != out_channel: | |||||
down_sample = True | |||||
resblk = block(in_channel, | |||||
out_channel, | |||||
stride=stride, | |||||
down_sample=down_sample, | |||||
training=training, | |||||
weights_update=weights_update) | |||||
layers.append(resblk) | |||||
for _ in range(1, layer_num): | |||||
resblk = block(out_channel, out_channel, stride=1, training=training, weights_update=weights_update) | |||||
layers.append(resblk) | |||||
return nn.SequentialCell(layers) | |||||
def construct(self, x): | |||||
x = self.conv1(x) | |||||
x = self.bn1(x) | |||||
x = self.relu(x) | |||||
c1 = self.maxpool(x) | |||||
c2 = self.layer1(c1) | |||||
identity = c2 | |||||
if not self.weights_update: | |||||
identity = F.stop_gradient(c2) | |||||
c3 = self.layer2(identity) | |||||
c4 = self.layer3(c3) | |||||
c5 = self.layer4(c4) | |||||
return identity, c3, c4, c5 | |||||
class ResidualBlockUsing(nn.Cell): | |||||
""" | |||||
ResNet V1 residual block definition. | |||||
Args: | |||||
in_channels (int) - Input channel. | |||||
out_channels (int) - Output channel. | |||||
stride (int) - Stride size for the initial convolutional layer. Default: 1. | |||||
down_sample (bool) - Whether to downsample the input in this block. Default: False.
momentum (float) - Momentum for batchnorm layer. Default: 0.1. | |||||
training (bool) - Training flag. Default: False. | |||||
weights_update (bool) - Weights update flag. Default: False.
Returns: | |||||
Tensor, output tensor. | |||||
Examples: | |||||
ResidualBlockUsing(3, 256, stride=2, down_sample=True)
""" | |||||
expansion = 4 | |||||
def __init__(self, | |||||
in_channels, | |||||
out_channels, | |||||
stride=1, | |||||
down_sample=False, | |||||
momentum=0.1, | |||||
training=False, | |||||
weights_update=False): | |||||
super(ResidualBlockUsing, self).__init__() | |||||
self.affine = weights_update | |||||
out_chls = out_channels // self.expansion | |||||
self.conv1 = _conv(in_channels, out_chls, kernel_size=1, stride=1, padding=0) | |||||
self.bn1 = _BatchNorm2dInit(out_chls, momentum=momentum, affine=self.affine, use_batch_statistics=training) | |||||
self.conv2 = _conv(out_chls, out_chls, kernel_size=3, stride=stride, padding=1) | |||||
self.bn2 = _BatchNorm2dInit(out_chls, momentum=momentum, affine=self.affine, use_batch_statistics=training) | |||||
self.conv3 = _conv(out_chls, out_channels, kernel_size=1, stride=1, padding=0) | |||||
self.bn3 = _BatchNorm2dInit(out_channels, momentum=momentum, affine=self.affine, use_batch_statistics=training) | |||||
if training: | |||||
self.bn1 = self.bn1.set_train() | |||||
self.bn2 = self.bn2.set_train() | |||||
self.bn3 = self.bn3.set_train() | |||||
if not weights_update: | |||||
self.conv1.weight.requires_grad = False | |||||
self.conv2.weight.requires_grad = False | |||||
self.conv3.weight.requires_grad = False | |||||
self.relu = P.ReLU() | |||||
self.downsample = down_sample | |||||
if self.downsample: | |||||
self.conv_down_sample = _conv(in_channels, out_channels, kernel_size=1, stride=stride, padding=0) | |||||
self.bn_down_sample = _BatchNorm2dInit(out_channels, momentum=momentum, affine=self.affine, | |||||
use_batch_statistics=training) | |||||
if training: | |||||
self.bn_down_sample = self.bn_down_sample.set_train() | |||||
if not weights_update: | |||||
self.conv_down_sample.weight.requires_grad = False | |||||
self.add = P.TensorAdd() | |||||
def construct(self, x): | |||||
identity = x | |||||
out = self.conv1(x) | |||||
out = self.bn1(out) | |||||
out = self.relu(out) | |||||
out = self.conv2(out) | |||||
out = self.bn2(out) | |||||
out = self.relu(out) | |||||
out = self.conv3(out) | |||||
out = self.bn3(out) | |||||
if self.downsample: | |||||
identity = self.conv_down_sample(identity) | |||||
identity = self.bn_down_sample(identity) | |||||
out = self.add(out, identity) | |||||
out = self.relu(out) | |||||
return out |
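A usage sketch, assuming an Ascend device and the MindSpore version this repository targets: build the backbone exactly as in the docstring example and push a dummy 768x1280 image through it to inspect the stride-4/8/16/32 feature maps.
```
import numpy as np
from mindspore import Tensor

backbone = ResNetFea(ResidualBlockUsing,
                     [3, 4, 6, 3],
                     [64, 256, 512, 1024],
                     [256, 512, 1024, 2048],
                     weights_update=False)
dummy = Tensor(np.zeros((1, 3, 768, 1280), np.float16))  # batch of one image
c2, c3, c4, c5 = backbone(dummy)
print(c2.shape, c3.shape, c4.shape, c5.shape)
```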
@@ -0,0 +1,181 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""FasterRcnn ROIAlign module.""" | |||||
import numpy as np | |||||
import mindspore.nn as nn | |||||
import mindspore.common.dtype as mstype | |||||
from mindspore.ops import operations as P | |||||
from mindspore.ops import composite as C | |||||
from mindspore.nn import layer as L | |||||
from mindspore.common.tensor import Tensor | |||||
# pylint: disable=locally-disabled, invalid-name, missing-docstring | |||||
class ROIAlign(nn.Cell): | |||||
""" | |||||
Extract RoI features from multiple feature maps.
Args: | |||||
out_size_h (int) - RoI height. | |||||
out_size_w (int) - RoI width. | |||||
spatial_scale (float) - RoI spatial scale.
sample_num (int) - RoI sample number. | |||||
""" | |||||
def __init__(self, | |||||
out_size_h, | |||||
out_size_w, | |||||
spatial_scale, | |||||
sample_num=0): | |||||
super(ROIAlign, self).__init__() | |||||
self.out_size = (out_size_h, out_size_w) | |||||
self.spatial_scale = float(spatial_scale) | |||||
self.sample_num = int(sample_num) | |||||
self.align_op = P.ROIAlign(self.out_size[0], self.out_size[1], | |||||
self.spatial_scale, self.sample_num) | |||||
def construct(self, features, rois): | |||||
return self.align_op(features, rois) | |||||
def __repr__(self): | |||||
format_str = self.__class__.__name__ | |||||
format_str += '(out_size={}, spatial_scale={}, sample_num={})'.format(
self.out_size, self.spatial_scale, self.sample_num) | |||||
return format_str | |||||
class SingleRoIExtractor(nn.Cell): | |||||
""" | |||||
Extract RoI features from a single level feature map. | |||||
If there are multiple input feature levels, each RoI is mapped to a level
according to its scale. | |||||
Args: | |||||
config (dict): Config | |||||
roi_layer (dict): Specify RoI layer type and arguments. | |||||
out_channels (int): Output channels of RoI layers. | |||||
featmap_strides (list): Strides of input feature maps.
batch_size (int): Batchsize. | |||||
finest_scale (int): Scale threshold of mapping to level 0. | |||||
""" | |||||
def __init__(self, | |||||
config, | |||||
roi_layer, | |||||
out_channels, | |||||
featmap_strides, | |||||
batch_size=1, | |||||
finest_scale=56): | |||||
super(SingleRoIExtractor, self).__init__() | |||||
cfg = config | |||||
self.train_batch_size = batch_size | |||||
self.out_channels = out_channels | |||||
self.featmap_strides = featmap_strides | |||||
self.num_levels = len(self.featmap_strides) | |||||
self.out_size = roi_layer['out_size'] | |||||
self.sample_num = roi_layer['sample_num'] | |||||
self.roi_layers = self.build_roi_layers(self.featmap_strides) | |||||
self.roi_layers = L.CellList(self.roi_layers) | |||||
self.sqrt = P.Sqrt() | |||||
self.log = P.Log() | |||||
self.finest_scale_ = finest_scale | |||||
self.clamp = C.clip_by_value | |||||
self.cast = P.Cast() | |||||
self.equal = P.Equal() | |||||
self.select = P.Select() | |||||
_mode_16 = False | |||||
self.dtype = np.float16 if _mode_16 else np.float32 | |||||
self.ms_dtype = mstype.float16 if _mode_16 else mstype.float32 | |||||
self.set_train_local(cfg, training=True) | |||||
def set_train_local(self, config, training=True): | |||||
"""Set training flag.""" | |||||
self.training_local = training | |||||
cfg = config | |||||
# Init tensor | |||||
self.batch_size = cfg.roi_sample_num if self.training_local else cfg.rpn_max_num | |||||
self.batch_size = self.train_batch_size*self.batch_size \ | |||||
if self.training_local else cfg.test_batch_size*self.batch_size | |||||
self.ones = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype)) | |||||
finest_scale = np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * self.finest_scale_ | |||||
self.finest_scale = Tensor(finest_scale) | |||||
self.epslion = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype)*self.dtype(1e-6)) | |||||
self.zeros = Tensor(np.array(np.zeros((self.batch_size, 1)), dtype=np.int32)) | |||||
self.max_levels = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=np.int32)*(self.num_levels-1)) | |||||
self.twos = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * 2) | |||||
self.res_ = Tensor(np.array(np.zeros((self.batch_size, self.out_channels, | |||||
self.out_size, self.out_size)), dtype=self.dtype)) | |||||
def num_inputs(self): | |||||
return len(self.featmap_strides) | |||||
def init_weights(self): | |||||
pass | |||||
def log2(self, value): | |||||
return self.log(value) / self.log(self.twos) | |||||
def build_roi_layers(self, featmap_strides): | |||||
roi_layers = [] | |||||
for s in featmap_strides: | |||||
layer_cls = ROIAlign(self.out_size, self.out_size, | |||||
spatial_scale=1 / s, | |||||
sample_num=self.sample_num) | |||||
roi_layers.append(layer_cls) | |||||
return roi_layers | |||||
def _c_map_roi_levels(self, rois): | |||||
"""Map rois to corresponding feature levels by scales. | |||||
- scale < finest_scale * 2: level 0 | |||||
- finest_scale * 2 <= scale < finest_scale * 4: level 1 | |||||
- finest_scale * 4 <= scale < finest_scale * 8: level 2 | |||||
- scale >= finest_scale * 8: level 3 | |||||
Args: | |||||
rois (Tensor): Input RoIs, shape (k, 5). | |||||
num_levels (int): Total level number. | |||||
Returns: | |||||
Tensor: Level index (0-based) of each RoI, shape (k, ) | |||||
""" | |||||
scale = self.sqrt(rois[::, 3:4:1] - rois[::, 1:2:1] + self.ones) * \ | |||||
self.sqrt(rois[::, 4:5:1] - rois[::, 2:3:1] + self.ones) | |||||
target_lvls = self.log2(scale / self.finest_scale + self.epslion) | |||||
target_lvls = P.Floor()(target_lvls) | |||||
target_lvls = self.cast(target_lvls, mstype.int32) | |||||
target_lvls = self.clamp(target_lvls, self.zeros, self.max_levels) | |||||
return target_lvls | |||||
def construct(self, rois, feat1, feat2, feat3, feat4): | |||||
feats = (feat1, feat2, feat3, feat4) | |||||
res = self.res_ | |||||
target_lvls = self._c_map_roi_levels(rois) | |||||
for i in range(self.num_levels): | |||||
mask = self.equal(target_lvls, P.ScalarToArray()(i)) | |||||
mask = P.Reshape()(mask, (-1, 1, 1, 1)) | |||||
roi_feats_t = self.roi_layers[i](feats[i], rois) | |||||
mask = self.cast(P.Tile()(self.cast(mask, mstype.int32), (1, 256, 7, 7)), mstype.bool_) | |||||
res = self.select(mask, roi_feats_t, res) | |||||
return res |
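The level-mapping rule documented in `_c_map_roi_levels` is roughly `level = clamp(floor(log2(sqrt(w * h) / finest_scale)), 0, num_levels - 1)`. Below is a stand-alone NumPy check with arbitrary RoI sizes; the real code additionally adds one pixel to each side and operates on tensors.
```
import numpy as np

def map_roi_level(widths, heights, finest_scale=56, num_levels=4):
    scale = np.sqrt(widths * heights)
    levels = np.floor(np.log2(scale / finest_scale + 1e-6))
    return np.clip(levels, 0, num_levels - 1).astype(np.int32)

print(map_roi_level(np.array([40.0, 150.0, 300.0, 900.0]),
                    np.array([40.0, 150.0, 300.0, 900.0])))
# -> [0 1 2 3]: small boxes go to the finest FPN level, large ones to the coarsest
```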
@@ -0,0 +1,315 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""RPN for fasterRCNN""" | |||||
import numpy as np | |||||
import mindspore.nn as nn | |||||
import mindspore.common.dtype as mstype | |||||
from mindspore.ops import operations as P | |||||
from mindspore import Tensor | |||||
from mindspore.ops import functional as F | |||||
from mindspore.common.initializer import initializer | |||||
from .bbox_assign_sample import BboxAssignSample | |||||
# pylint: disable=locally-disabled, invalid-name, missing-docstring | |||||
class RpnRegClsBlock(nn.Cell): | |||||
""" | |||||
Rpn reg cls block for rpn layer | |||||
Args: | |||||
in_channels (int) - Input channels of shared convolution. | |||||
feat_channels (int) - Output channels of shared convolution. | |||||
num_anchors (int) - The anchor number. | |||||
cls_out_channels (int) - Output channels of classification convolution. | |||||
weight_conv (Tensor) - weight init for rpn conv. | |||||
bias_conv (Tensor) - bias init for rpn conv. | |||||
weight_cls (Tensor) - weight init for rpn cls conv. | |||||
bias_cls (Tensor) - bias init for rpn cls conv. | |||||
weight_reg (Tensor) - weight init for rpn reg conv. | |||||
bias_reg (Tensor) - bias init for rpn reg conv. | |||||
Returns: | |||||
Tensor, output tensor. | |||||
""" | |||||
def __init__(self, | |||||
in_channels, | |||||
feat_channels, | |||||
num_anchors, | |||||
cls_out_channels, | |||||
weight_conv, | |||||
bias_conv, | |||||
weight_cls, | |||||
bias_cls, | |||||
weight_reg, | |||||
bias_reg): | |||||
super(RpnRegClsBlock, self).__init__() | |||||
self.rpn_conv = nn.Conv2d(in_channels, feat_channels, kernel_size=3, stride=1, pad_mode='same', | |||||
has_bias=True, weight_init=weight_conv, bias_init=bias_conv) | |||||
self.relu = nn.ReLU() | |||||
self.rpn_cls = nn.Conv2d(feat_channels, num_anchors * cls_out_channels, kernel_size=1, pad_mode='valid', | |||||
has_bias=True, weight_init=weight_cls, bias_init=bias_cls) | |||||
self.rpn_reg = nn.Conv2d(feat_channels, num_anchors * 4, kernel_size=1, pad_mode='valid', | |||||
has_bias=True, weight_init=weight_reg, bias_init=bias_reg) | |||||
def construct(self, x): | |||||
x = self.relu(self.rpn_conv(x)) | |||||
x1 = self.rpn_cls(x) | |||||
x2 = self.rpn_reg(x) | |||||
return x1, x2 | |||||
class RPN(nn.Cell): | |||||
""" | |||||
Region proposal network.
Args: | |||||
config (dict) - Config. | |||||
batch_size (int) - Batchsize. | |||||
in_channels (int) - Input channels of shared convolution. | |||||
feat_channels (int) - Output channels of shared convolution. | |||||
num_anchors (int) - The anchor number. | |||||
cls_out_channels (int) - Output channels of classification convolution. | |||||
Returns: | |||||
Tuple, tuple of output tensor. | |||||
Examples: | |||||
RPN(config=config, batch_size=2, in_channels=256, feat_channels=1024, | |||||
num_anchors=3, cls_out_channels=512) | |||||
""" | |||||
def __init__(self, | |||||
config, | |||||
batch_size, | |||||
in_channels, | |||||
feat_channels, | |||||
num_anchors, | |||||
cls_out_channels): | |||||
super(RPN, self).__init__() | |||||
cfg_rpn = config | |||||
self.num_bboxes = cfg_rpn.num_bboxes | |||||
self.slice_index = () | |||||
self.feature_anchor_shape = () | |||||
self.slice_index += (0,) | |||||
index = 0 | |||||
for shape in cfg_rpn.feature_shapes: | |||||
self.slice_index += (self.slice_index[index] + shape[0] * shape[1] * num_anchors,) | |||||
self.feature_anchor_shape += (shape[0] * shape[1] * num_anchors * batch_size,) | |||||
index += 1 | |||||
self.num_anchors = num_anchors | |||||
self.batch_size = batch_size | |||||
self.test_batch_size = cfg_rpn.test_batch_size | |||||
self.num_layers = 5 | |||||
self.real_ratio = Tensor(np.ones((1, 1)).astype(np.float16)) | |||||
self.rpn_convs_list = nn.layer.CellList(self._make_rpn_layer(self.num_layers, in_channels, feat_channels, | |||||
num_anchors, cls_out_channels)) | |||||
self.transpose = P.Transpose() | |||||
self.reshape = P.Reshape() | |||||
self.concat = P.Concat(axis=0) | |||||
self.fill = P.Fill() | |||||
self.placeh1 = Tensor(np.ones((1,)).astype(np.float16)) | |||||
self.trans_shape = (0, 2, 3, 1) | |||||
self.reshape_shape_reg = (-1, 4) | |||||
self.reshape_shape_cls = (-1,) | |||||
self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(np.float16)) | |||||
self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(np.float16)) | |||||
self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(np.float16)) | |||||
self.num_bboxes = cfg_rpn.num_bboxes | |||||
self.get_targets = BboxAssignSample(cfg_rpn, self.batch_size, self.num_bboxes, False) | |||||
self.CheckValid = P.CheckValid() | |||||
self.sum_loss = P.ReduceSum() | |||||
self.loss_cls = P.SigmoidCrossEntropyWithLogits() | |||||
self.loss_bbox = P.SmoothL1Loss(beta=1.0/9.0) | |||||
self.squeeze = P.Squeeze() | |||||
self.cast = P.Cast() | |||||
self.tile = P.Tile() | |||||
self.zeros_like = P.ZerosLike() | |||||
self.loss = Tensor(np.zeros((1,)).astype(np.float16)) | |||||
self.clsloss = Tensor(np.zeros((1,)).astype(np.float16)) | |||||
self.regloss = Tensor(np.zeros((1,)).astype(np.float16)) | |||||
def _make_rpn_layer(self, num_layers, in_channels, feat_channels, num_anchors, cls_out_channels): | |||||
""" | |||||
make rpn layer for rpn proposal network | |||||
Args: | |||||
num_layers (int) - layer num. | |||||
in_channels (int) - Input channels of shared convolution. | |||||
feat_channels (int) - Output channels of shared convolution. | |||||
num_anchors (int) - The anchor number. | |||||
cls_out_channels (int) - Output channels of classification convolution. | |||||
Returns: | |||||
List, list of RpnRegClsBlock cells. | |||||
""" | |||||
rpn_layer = [] | |||||
shp_weight_conv = (feat_channels, in_channels, 3, 3) | |||||
shp_bias_conv = (feat_channels,) | |||||
weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float16).to_tensor() | |||||
bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float16).to_tensor() | |||||
shp_weight_cls = (num_anchors * cls_out_channels, feat_channels, 1, 1) | |||||
shp_bias_cls = (num_anchors * cls_out_channels,) | |||||
weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=mstype.float16).to_tensor() | |||||
bias_cls = initializer(0, shape=shp_bias_cls, dtype=mstype.float16).to_tensor() | |||||
shp_weight_reg = (num_anchors * 4, feat_channels, 1, 1) | |||||
shp_bias_reg = (num_anchors * 4,) | |||||
weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float16).to_tensor() | |||||
bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float16).to_tensor() | |||||
for i in range(num_layers): | |||||
rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \ | |||||
weight_conv, bias_conv, weight_cls, \ | |||||
bias_cls, weight_reg, bias_reg)) | |||||
for i in range(1, num_layers): | |||||
rpn_layer[i].rpn_conv.weight = rpn_layer[0].rpn_conv.weight | |||||
rpn_layer[i].rpn_cls.weight = rpn_layer[0].rpn_cls.weight | |||||
rpn_layer[i].rpn_reg.weight = rpn_layer[0].rpn_reg.weight | |||||
rpn_layer[i].rpn_conv.bias = rpn_layer[0].rpn_conv.bias | |||||
rpn_layer[i].rpn_cls.bias = rpn_layer[0].rpn_cls.bias | |||||
rpn_layer[i].rpn_reg.bias = rpn_layer[0].rpn_reg.bias | |||||
return rpn_layer | |||||
def construct(self, inputs, img_metas, anchor_list, gt_bboxes, gt_labels, gt_valids): | |||||
loss_print = () | |||||
rpn_cls_score = () | |||||
rpn_bbox_pred = () | |||||
rpn_cls_score_total = () | |||||
rpn_bbox_pred_total = () | |||||
for i in range(self.num_layers): | |||||
x1, x2 = self.rpn_convs_list[i](inputs[i]) | |||||
rpn_cls_score_total = rpn_cls_score_total + (x1,) | |||||
rpn_bbox_pred_total = rpn_bbox_pred_total + (x2,) | |||||
x1 = self.transpose(x1, self.trans_shape) | |||||
x1 = self.reshape(x1, self.reshape_shape_cls) | |||||
x2 = self.transpose(x2, self.trans_shape) | |||||
x2 = self.reshape(x2, self.reshape_shape_reg) | |||||
rpn_cls_score = rpn_cls_score + (x1,) | |||||
rpn_bbox_pred = rpn_bbox_pred + (x2,) | |||||
loss = self.loss | |||||
clsloss = self.clsloss | |||||
regloss = self.regloss | |||||
bbox_targets = () | |||||
bbox_weights = () | |||||
labels = () | |||||
label_weights = () | |||||
output = () | |||||
if self.training: | |||||
for i in range(self.batch_size): | |||||
multi_level_flags = () | |||||
anchor_list_tuple = () | |||||
for j in range(self.num_layers): | |||||
res = self.cast(self.CheckValid(anchor_list[j], self.squeeze(img_metas[i:i + 1:1, ::])), | |||||
mstype.int32) | |||||
multi_level_flags = multi_level_flags + (res,) | |||||
anchor_list_tuple = anchor_list_tuple + (anchor_list[j],) | |||||
valid_flag_list = self.concat(multi_level_flags) | |||||
anchor_using_list = self.concat(anchor_list_tuple) | |||||
gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::]) | |||||
gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::]) | |||||
gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::]) | |||||
bbox_target, bbox_weight, label, label_weight = self.get_targets(gt_bboxes_i, | |||||
gt_labels_i, | |||||
self.cast(valid_flag_list, | |||||
mstype.bool_), | |||||
anchor_using_list, gt_valids_i) | |||||
bbox_weight = self.cast(bbox_weight, mstype.float16) | |||||
label = self.cast(label, mstype.float16) | |||||
label_weight = self.cast(label_weight, mstype.float16) | |||||
for j in range(self.num_layers): | |||||
begin = self.slice_index[j] | |||||
end = self.slice_index[j + 1] | |||||
stride = 1 | |||||
bbox_targets += (bbox_target[begin:end:stride, ::],) | |||||
bbox_weights += (bbox_weight[begin:end:stride],) | |||||
labels += (label[begin:end:stride],) | |||||
label_weights += (label_weight[begin:end:stride],) | |||||
for i in range(self.num_layers): | |||||
bbox_target_using = () | |||||
bbox_weight_using = () | |||||
label_using = () | |||||
label_weight_using = () | |||||
for j in range(self.batch_size): | |||||
bbox_target_using += (bbox_targets[i + (self.num_layers * j)],) | |||||
bbox_weight_using += (bbox_weights[i + (self.num_layers * j)],) | |||||
label_using += (labels[i + (self.num_layers * j)],) | |||||
label_weight_using += (label_weights[i + (self.num_layers * j)],) | |||||
bbox_target_with_batchsize = self.concat(bbox_target_using) | |||||
bbox_weight_with_batchsize = self.concat(bbox_weight_using) | |||||
label_with_batchsize = self.concat(label_using) | |||||
label_weight_with_batchsize = self.concat(label_weight_using) | |||||
# block gradients from flowing back through the assigned targets
bbox_target_ = F.stop_gradient(bbox_target_with_batchsize) | |||||
bbox_weight_ = F.stop_gradient(bbox_weight_with_batchsize) | |||||
label_ = F.stop_gradient(label_with_batchsize) | |||||
label_weight_ = F.stop_gradient(label_weight_with_batchsize) | |||||
cls_score_i = rpn_cls_score[i] | |||||
reg_score_i = rpn_bbox_pred[i] | |||||
loss_cls = self.loss_cls(cls_score_i, label_) | |||||
loss_cls_item = loss_cls * label_weight_ | |||||
loss_cls_item = self.sum_loss(loss_cls_item, (0,)) / self.num_expected_total | |||||
loss_reg = self.loss_bbox(reg_score_i, bbox_target_) | |||||
bbox_weight_ = self.tile(self.reshape(bbox_weight_, (self.feature_anchor_shape[i], 1)), (1, 4)) | |||||
loss_reg = loss_reg * bbox_weight_ | |||||
loss_reg_item = self.sum_loss(loss_reg, (1,)) | |||||
loss_reg_item = self.sum_loss(loss_reg_item, (0,)) / self.num_expected_total | |||||
loss_total = self.rpn_loss_cls_weight * loss_cls_item + self.rpn_loss_reg_weight * loss_reg_item | |||||
loss += loss_total | |||||
loss_print += (loss_total, loss_cls_item, loss_reg_item) | |||||
clsloss += loss_cls_item | |||||
regloss += loss_reg_item | |||||
output = (loss, rpn_cls_score_total, rpn_bbox_pred_total, clsloss, regloss, loss_print) | |||||
else: | |||||
output = (self.placeh1, rpn_cls_score_total, rpn_bbox_pred_total, self.placeh1, self.placeh1, self.placeh1) | |||||
return output |
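The `slice_index` and `feature_anchor_shape` bookkeeping above simply accumulates anchor counts per FPN level. Here is a quick stand-alone check using the `feature_shapes` and `num_anchors` values from `src/config.py`; `batch_size=2` is only an example value.
```
feature_shapes = [(192, 320), (96, 160), (48, 80), (24, 40), (12, 20)]
num_anchors, batch_size = 3, 2

slice_index = (0,)
feature_anchor_shape = ()
for h, w in feature_shapes:
    slice_index += (slice_index[-1] + h * w * num_anchors,)
    feature_anchor_shape += (h * w * num_anchors * batch_size,)

print(slice_index)           # cumulative anchors per image, ending at 245520
print(feature_anchor_shape)  # per-level anchor counts for the whole batch
```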
@@ -0,0 +1,158 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# =========================================================================== | |||||
""" | |||||
Network config settings, used in train.py and eval.py.
""" | |||||
from easydict import EasyDict as ed | |||||
config = ed({ | |||||
"img_width": 1280, | |||||
"img_height": 768, | |||||
"keep_ratio": False, | |||||
"flip_ratio": 0.5, | |||||
"photo_ratio": 0.5, | |||||
"expand_ratio": 1.0, | |||||
# anchor | |||||
"feature_shapes": [(192, 320), (96, 160), (48, 80), (24, 40), (12, 20)], | |||||
"anchor_scales": [8], | |||||
"anchor_ratios": [0.5, 1.0, 2.0], | |||||
"anchor_strides": [4, 8, 16, 32, 64], | |||||
"num_anchors": 3, | |||||
# resnet | |||||
"resnet_block": [3, 4, 6, 3], | |||||
"resnet_in_channels": [64, 256, 512, 1024], | |||||
"resnet_out_channels": [256, 512, 1024, 2048], | |||||
# fpn | |||||
"fpn_in_channels": [256, 512, 1024, 2048], | |||||
"fpn_out_channels": 256, | |||||
"fpn_num_outs": 5, | |||||
# rpn | |||||
"rpn_in_channels": 256, | |||||
"rpn_feat_channels": 256, | |||||
"rpn_loss_cls_weight": 1.0, | |||||
"rpn_loss_reg_weight": 1.0, | |||||
"rpn_cls_out_channels": 1, | |||||
"rpn_target_means": [0., 0., 0., 0.], | |||||
"rpn_target_stds": [1.0, 1.0, 1.0, 1.0], | |||||
# bbox_assign_sampler | |||||
"neg_iou_thr": 0.3, | |||||
"pos_iou_thr": 0.7, | |||||
"min_pos_iou": 0.3, | |||||
"num_bboxes": 245520, | |||||
"num_gts": 128, | |||||
"num_expected_neg": 256, | |||||
"num_expected_pos": 128, | |||||
# proposal | |||||
"activate_num_classes": 2, | |||||
"use_sigmoid_cls": True, | |||||
# roi_align | |||||
"roi_layer": dict(type='RoIAlign', out_size=7, sample_num=2), | |||||
"roi_align_out_channels": 256, | |||||
"roi_align_featmap_strides": [4, 8, 16, 32], | |||||
"roi_align_finest_scale": 56, | |||||
"roi_sample_num": 640, | |||||
# bbox_assign_sampler_stage2 | |||||
"neg_iou_thr_stage2": 0.5, | |||||
"pos_iou_thr_stage2": 0.5, | |||||
"min_pos_iou_stage2": 0.5, | |||||
"num_bboxes_stage2": 2000, | |||||
"num_expected_pos_stage2": 128, | |||||
"num_expected_neg_stage2": 512, | |||||
"num_expected_total_stage2": 512, | |||||
# rcnn | |||||
"rcnn_num_layers": 2, | |||||
"rcnn_in_channels": 256, | |||||
"rcnn_fc_out_channels": 1024, | |||||
"rcnn_loss_cls_weight": 1, | |||||
"rcnn_loss_reg_weight": 1, | |||||
"rcnn_target_means": [0., 0., 0., 0.], | |||||
"rcnn_target_stds": [0.1, 0.1, 0.2, 0.2], | |||||
# train proposal | |||||
"rpn_proposal_nms_across_levels": False, | |||||
"rpn_proposal_nms_pre": 2000, | |||||
"rpn_proposal_nms_post": 2000, | |||||
"rpn_proposal_max_num": 2000, | |||||
"rpn_proposal_nms_thr": 0.7, | |||||
"rpn_proposal_min_bbox_size": 0, | |||||
# test proposal | |||||
"rpn_nms_across_levels": False, | |||||
"rpn_nms_pre": 1000, | |||||
"rpn_nms_post": 1000, | |||||
"rpn_max_num": 1000, | |||||
"rpn_nms_thr": 0.7, | |||||
"rpn_min_bbox_min_size": 0, | |||||
"test_score_thr": 0.05, | |||||
"test_iou_thr": 0.5, | |||||
"test_max_per_img": 100, | |||||
"test_batch_size": 1, | |||||
"rpn_head_loss_type": "CrossEntropyLoss", | |||||
"rpn_head_use_sigmoid": True, | |||||
"rpn_head_weight": 1.0, | |||||
# LR | |||||
"base_lr": 0.02, | |||||
"base_step": 58633, | |||||
"total_epoch": 13, | |||||
"warmup_step": 500, | |||||
"warmup_mode": "linear", | |||||
"warmup_ratio": 1/3.0, | |||||
"sgd_step": [8, 11], | |||||
"sgd_momentum": 0.9, | |||||
# train | |||||
"batch_size": 1, | |||||
"loss_scale": 1, | |||||
"momentum": 0.91, | |||||
"weight_decay": 1e-4, | |||||
"epoch_size": 12, | |||||
"save_checkpoint": True, | |||||
"save_checkpoint_epochs": 1, | |||||
"keep_checkpoint_max": 10, | |||||
"save_checkpoint_path": "./", | |||||
"mindrecord_dir": "../MindRecord_COCO_TRAIN", | |||||
"coco_root": "./cocodataset/", | |||||
"train_data_type": "train2017", | |||||
"val_data_type": "val2017", | |||||
"instance_set": "annotations/instances_{}.json", | |||||
"coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', | |||||
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', | |||||
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', | |||||
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', | |||||
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', | |||||
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', | |||||
'kite', 'baseball bat', 'baseball glove', 'skateboard', | |||||
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', | |||||
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', | |||||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', | |||||
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', | |||||
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', | |||||
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', | |||||
'refrigerator', 'book', 'clock', 'vase', 'scissors', | |||||
'teddy bear', 'hair drier', 'toothbrush'), | |||||
"num_classes": 81 | |||||
}) |
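As a sanity check (a sketch, not part of the shipped code), the `feature_shapes` entries follow from the 768x1280 input divided by the FPN `anchor_strides`, and `num_bboxes` is their total anchor count:
```
img_height, img_width = 768, 1280
anchor_strides = [4, 8, 16, 32, 64]
num_anchors = 3

feature_shapes = [(img_height // s, img_width // s) for s in anchor_strides]
num_bboxes = sum(h * w * num_anchors for h, w in feature_shapes)
print(feature_shapes)  # [(192, 320), (96, 160), (48, 80), (24, 40), (12, 20)]
print(num_bboxes)      # 245520, matching config["num_bboxes"]
```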
@@ -0,0 +1,505 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""FasterRcnn dataset""" | |||||
from __future__ import division | |||||
import os | |||||
import numpy as np | |||||
from numpy import random | |||||
import mmcv | |||||
import mindspore.dataset as de | |||||
import mindspore.dataset.vision.c_transforms as C | |||||
import mindspore.dataset.transforms.c_transforms as CC | |||||
import mindspore.common.dtype as mstype | |||||
from mindspore.mindrecord import FileWriter | |||||
from src.config import config | |||||
# pylint: disable=locally-disabled, unused-variable | |||||
def bbox_overlaps(bboxes1, bboxes2, mode='iou'): | |||||
"""Calculate the ious between each bbox of bboxes1 and bboxes2. | |||||
Args: | |||||
bboxes1(ndarray): shape (n, 4) | |||||
bboxes2(ndarray): shape (k, 4) | |||||
mode(str): iou (intersection over union) or iof (intersection | |||||
over foreground) | |||||
Returns: | |||||
ious(ndarray): shape (n, k) | |||||
""" | |||||
assert mode in ['iou', 'iof'] | |||||
bboxes1 = bboxes1.astype(np.float32) | |||||
bboxes2 = bboxes2.astype(np.float32) | |||||
rows = bboxes1.shape[0] | |||||
cols = bboxes2.shape[0] | |||||
ious = np.zeros((rows, cols), dtype=np.float32) | |||||
if rows * cols == 0: | |||||
return ious | |||||
exchange = False | |||||
if bboxes1.shape[0] > bboxes2.shape[0]: | |||||
bboxes1, bboxes2 = bboxes2, bboxes1 | |||||
ious = np.zeros((cols, rows), dtype=np.float32) | |||||
exchange = True | |||||
area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * (bboxes1[:, 3] - bboxes1[:, 1] + 1) | |||||
area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * (bboxes2[:, 3] - bboxes2[:, 1] + 1) | |||||
for i in range(bboxes1.shape[0]): | |||||
x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0]) | |||||
y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1]) | |||||
x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2]) | |||||
y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3]) | |||||
overlap = np.maximum(x_end - x_start + 1, 0) * np.maximum( | |||||
y_end - y_start + 1, 0) | |||||
if mode == 'iou': | |||||
union = area1[i] + area2 - overlap | |||||
else: | |||||
union = area1[i] if not exchange else area2 | |||||
ious[i, :] = overlap / union | |||||
if exchange: | |||||
ious = ious.T | |||||
return ious | |||||
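# Illustrative example: for bboxes1 = [[0, 0, 10, 10]] and
# bboxes2 = [[0, 0, 10, 10], [5, 5, 15, 15]], bbox_overlaps returns roughly
# [[1.0, 0.17]]: the identical box overlaps fully, while the shifted box shares
# a 6x6 patch out of a union of 121 + 121 - 36 pixels (note the +1 edge convention).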
class PhotoMetricDistortion: | |||||
"""Photo Metric Distortion""" | |||||
def __init__(self, | |||||
brightness_delta=32, | |||||
contrast_range=(0.5, 1.5), | |||||
saturation_range=(0.5, 1.5), | |||||
hue_delta=18): | |||||
self.brightness_delta = brightness_delta | |||||
self.contrast_lower, self.contrast_upper = contrast_range | |||||
self.saturation_lower, self.saturation_upper = saturation_range | |||||
self.hue_delta = hue_delta | |||||
def __call__(self, img, boxes, labels): | |||||
# random brightness | |||||
img = img.astype('float32') | |||||
if random.randint(2): | |||||
delta = random.uniform(-self.brightness_delta, | |||||
self.brightness_delta) | |||||
img += delta | |||||
# mode == 0 --> do random contrast first | |||||
# mode == 1 --> do random contrast last | |||||
mode = random.randint(2) | |||||
if mode == 1: | |||||
if random.randint(2): | |||||
alpha = random.uniform(self.contrast_lower, | |||||
self.contrast_upper) | |||||
img *= alpha | |||||
# convert color from BGR to HSV | |||||
img = mmcv.bgr2hsv(img) | |||||
# random saturation | |||||
if random.randint(2): | |||||
img[..., 1] *= random.uniform(self.saturation_lower, | |||||
self.saturation_upper) | |||||
# random hue | |||||
if random.randint(2): | |||||
img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta) | |||||
img[..., 0][img[..., 0] > 360] -= 360 | |||||
img[..., 0][img[..., 0] < 0] += 360 | |||||
# convert color from HSV to BGR | |||||
img = mmcv.hsv2bgr(img) | |||||
# random contrast | |||||
if mode == 0: | |||||
if random.randint(2): | |||||
alpha = random.uniform(self.contrast_lower, | |||||
self.contrast_upper) | |||||
img *= alpha | |||||
# randomly swap channels | |||||
if random.randint(2): | |||||
img = img[..., random.permutation(3)] | |||||
return img, boxes, labels | |||||
class Expand: | |||||
"""expand image""" | |||||
def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)): | |||||
if to_rgb: | |||||
self.mean = mean[::-1] | |||||
else: | |||||
self.mean = mean | |||||
self.min_ratio, self.max_ratio = ratio_range | |||||
def __call__(self, img, boxes, labels): | |||||
if random.randint(2): | |||||
return img, boxes, labels | |||||
h, w, c = img.shape | |||||
ratio = random.uniform(self.min_ratio, self.max_ratio) | |||||
expand_img = np.full((int(h * ratio), int(w * ratio), c), | |||||
self.mean).astype(img.dtype) | |||||
left = int(random.uniform(0, w * ratio - w)) | |||||
top = int(random.uniform(0, h * ratio - h)) | |||||
expand_img[top:top + h, left:left + w] = img | |||||
img = expand_img | |||||
boxes += np.tile((left, top), 2) | |||||
return img, boxes, labels | |||||
def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num): | |||||
"""rescale operation for image""" | |||||
img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True) | |||||
if img_data.shape[0] > config.img_height: | |||||
img_data, scale_factor2 = mmcv.imrescale(img_data, (config.img_height, config.img_width), return_scale=True) | |||||
scale_factor = scale_factor * scale_factor2 | |||||
img_shape = np.append(img_shape, scale_factor) | |||||
img_shape = np.asarray(img_shape, dtype=np.float32) | |||||
gt_bboxes = gt_bboxes * scale_factor | |||||
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1) | |||||
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1) | |||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | |||||
def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num): | |||||
"""resize operation for image""" | |||||
img_data = img | |||||
img_data, w_scale, h_scale = mmcv.imresize( | |||||
img_data, (config.img_width, config.img_height), return_scale=True) | |||||
scale_factor = np.array( | |||||
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32) | |||||
img_shape = (config.img_height, config.img_width, 1.0) | |||||
img_shape = np.asarray(img_shape, dtype=np.float32) | |||||
gt_bboxes = gt_bboxes * scale_factor | |||||
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1) | |||||
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1) | |||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | |||||
def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num): | |||||
"""resize operation for image of eval""" | |||||
img_data = img | |||||
img_data, w_scale, h_scale = mmcv.imresize( | |||||
img_data, (config.img_width, config.img_height), return_scale=True) | |||||
scale_factor = np.array( | |||||
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32) | |||||
img_shape = np.append(img_shape, (h_scale, w_scale)) | |||||
img_shape = np.asarray(img_shape, dtype=np.float32) | |||||
gt_bboxes = gt_bboxes * scale_factor | |||||
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1) | |||||
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1) | |||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | |||||
def impad_to_multiple_column(img, img_shape, gt_bboxes, gt_label, gt_num): | |||||
"""impad operation for image""" | |||||
img_data = mmcv.impad(img, (config.img_height, config.img_width)) | |||||
img_data = img_data.astype(np.float32) | |||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | |||||
def imnormalize_column(img, img_shape, gt_bboxes, gt_label, gt_num): | |||||
"""imnormalize operation for image""" | |||||
img_data = mmcv.imnormalize(img, [123.675, 116.28, 103.53], [58.395, 57.12, 57.375], True) | |||||
img_data = img_data.astype(np.float32) | |||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | |||||
def flip_column(img, img_shape, gt_bboxes, gt_label, gt_num): | |||||
"""flip operation for image""" | |||||
img_data = img | |||||
img_data = mmcv.imflip(img_data) | |||||
flipped = gt_bboxes.copy() | |||||
_, w, _ = img_data.shape | |||||
flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1 | |||||
flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1 | |||||
return (img_data, img_shape, flipped, gt_label, gt_num) | |||||
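# Note on the flip formula used by flip_column / flipped_generation: with the
# inclusive-pixel convention, a box spanning x in [10, 30] on a 100-pixel-wide
# image maps to [69, 89] (x1' = w - x2 - 1, x2' = w - x1 - 1).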
def flipped_generation(img, img_shape, gt_bboxes, gt_label, gt_num): | |||||
"""flipped generation""" | |||||
img_data = img | |||||
flipped = gt_bboxes.copy() | |||||
_, w, _ = img_data.shape | |||||
flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1 | |||||
flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1 | |||||
return (img_data, img_shape, flipped, gt_label, gt_num) | |||||
def image_bgr_rgb(img, img_shape, gt_bboxes, gt_label, gt_num): | |||||
img_data = img[:, :, ::-1] | |||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | |||||
def transpose_column(img, img_shape, gt_bboxes, gt_label, gt_num): | |||||
"""transpose operation for image""" | |||||
img_data = img.transpose(2, 0, 1).copy() | |||||
img_data = img_data.astype(np.float16) | |||||
img_shape = img_shape.astype(np.float16) | |||||
gt_bboxes = gt_bboxes.astype(np.float16) | |||||
gt_label = gt_label.astype(np.int32) | |||||
gt_num = gt_num.astype(np.bool) | |||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | |||||
def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num): | |||||
"""photo crop operation for image""" | |||||
random_photo = PhotoMetricDistortion() | |||||
img_data, gt_bboxes, gt_label = random_photo(img, gt_bboxes, gt_label) | |||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | |||||
def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num): | |||||
"""expand operation for image""" | |||||
expand = Expand() | |||||
img, gt_bboxes, gt_label = expand(img, gt_bboxes, gt_label) | |||||
return (img, img_shape, gt_bboxes, gt_label, gt_num) | |||||
def preprocess_fn(image, box, is_training): | |||||
"""Preprocess function for dataset.""" | |||||
def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert): | |||||
image_shape = image_shape[:2] | |||||
input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert | |||||
if config.keep_ratio: | |||||
input_data = rescale_column(*input_data) | |||||
else: | |||||
input_data = resize_column_test(*input_data) | |||||
input_data = image_bgr_rgb(*input_data) | |||||
output_data = input_data | |||||
return output_data | |||||
def _data_aug(image, box, is_training): | |||||
"""Data augmentation function.""" | |||||
image_bgr = image.copy() | |||||
image_bgr[:, :, 0] = image[:, :, 2] | |||||
image_bgr[:, :, 1] = image[:, :, 1] | |||||
image_bgr[:, :, 2] = image[:, :, 0] | |||||
image_shape = image_bgr.shape[:2] | |||||
gt_box = box[:, :4] | |||||
gt_label = box[:, 4] | |||||
gt_iscrowd = box[:, 5] | |||||
pad_max_number = 128 | |||||
gt_box_new = np.pad(gt_box, ((0, pad_max_number - box.shape[0]), (0, 0)), mode="constant", constant_values=0) | |||||
gt_label_new = np.pad(gt_label, ((0, pad_max_number - box.shape[0])), mode="constant", constant_values=-1) | |||||
gt_iscrowd_new = np.pad(gt_iscrowd, ((0, pad_max_number - box.shape[0])), mode="constant", constant_values=1) | |||||
gt_iscrowd_new_revert = (~(gt_iscrowd_new.astype(np.bool))).astype(np.int32) | |||||
if not is_training: | |||||
return _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert) | |||||
input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert | |||||
if config.keep_ratio: | |||||
input_data = rescale_column(*input_data) | |||||
else: | |||||
input_data = resize_column(*input_data) | |||||
input_data = image_bgr_rgb(*input_data) | |||||
output_data = input_data | |||||
return output_data | |||||
return _data_aug(image, box, is_training) | |||||
def create_coco_label(is_training): | |||||
"""Get image path and annotation from COCO.""" | |||||
from pycocotools.coco import COCO | |||||
coco_root = config.coco_root | |||||
data_type = config.val_data_type | |||||
if is_training: | |||||
data_type = config.train_data_type | |||||
# Classes need to train or test. | |||||
train_cls = config.coco_classes | |||||
train_cls_dict = {} | |||||
for i, cls in enumerate(train_cls): | |||||
train_cls_dict[cls] = i | |||||
anno_json = os.path.join(coco_root, config.instance_set.format(data_type)) | |||||
coco = COCO(anno_json) | |||||
classs_dict = {} | |||||
cat_ids = coco.loadCats(coco.getCatIds()) | |||||
for cat in cat_ids: | |||||
classs_dict[cat["id"]] = cat["name"] | |||||
image_ids = coco.getImgIds() | |||||
image_files = [] | |||||
image_anno_dict = {} | |||||
for img_id in image_ids: | |||||
image_info = coco.loadImgs(img_id) | |||||
file_name = image_info[0]["file_name"] | |||||
anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None) | |||||
anno = coco.loadAnns(anno_ids) | |||||
image_path = os.path.join(coco_root, data_type, file_name) | |||||
annos = [] | |||||
for label in anno: | |||||
bbox = label["bbox"] | |||||
class_name = classs_dict[label["category_id"]] | |||||
if class_name in train_cls: | |||||
x1, x2 = bbox[0], bbox[0] + bbox[2] | |||||
y1, y2 = bbox[1], bbox[1] + bbox[3] | |||||
annos.append([x1, y1, x2, y2] + [train_cls_dict[class_name]] + [int(label["iscrowd"])]) | |||||
image_files.append(image_path) | |||||
if annos: | |||||
image_anno_dict[image_path] = np.array(annos) | |||||
else: | |||||
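# No usable annotation for this image: store a single dummy crowd box so the annotation array keeps its [N, 6] layout.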
image_anno_dict[image_path] = np.array([0, 0, 0, 0, 0, 1]) | |||||
return image_files, image_anno_dict | |||||
def anno_parser(annos_str): | |||||
"""Parse annotation from string to list.""" | |||||
annos = [] | |||||
for anno_str in annos_str: | |||||
anno = list(map(int, anno_str.strip().split(','))) | |||||
annos.append(anno) | |||||
return annos | |||||
def filter_valid_data(image_dir, anno_path): | |||||
"""Filter valid image file, which both in image_dir and anno_path.""" | |||||
image_files = [] | |||||
image_anno_dict = {} | |||||
if not os.path.isdir(image_dir): | |||||
raise RuntimeError("Path given is not valid.") | |||||
if not os.path.isfile(anno_path): | |||||
raise RuntimeError("Annotation file is not valid.") | |||||
with open(anno_path, "rb") as f: | |||||
lines = f.readlines() | |||||
for line in lines: | |||||
line_str = line.decode("utf-8").strip() | |||||
line_split = str(line_str).split(' ') | |||||
file_name = line_split[0] | |||||
image_path = os.path.join(image_dir, file_name) | |||||
if os.path.isfile(image_path): | |||||
image_anno_dict[image_path] = anno_parser(line_split[1:]) | |||||
image_files.append(image_path) | |||||
return image_files, image_anno_dict | |||||
def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="fasterrcnn.mindrecord", file_num=8): | |||||
"""Create MindRecord file.""" | |||||
mindrecord_dir = config.mindrecord_dir | |||||
mindrecord_path = os.path.join(mindrecord_dir, prefix) | |||||
writer = FileWriter(mindrecord_path, file_num) | |||||
if dataset == "coco": | |||||
image_files, image_anno_dict = create_coco_label(is_training) | |||||
else: | |||||
image_files, image_anno_dict = filter_valid_data(config.IMAGE_DIR, config.ANNO_PATH) | |||||
fasterrcnn_json = { | |||||
"image": {"type": "bytes"}, | |||||
"annotation": {"type": "int32", "shape": [-1, 6]}, | |||||
} | |||||
writer.add_schema(fasterrcnn_json, "fasterrcnn_json") | |||||
for image_name in image_files: | |||||
with open(image_name, 'rb') as f: | |||||
img = f.read() | |||||
annos = np.array(image_anno_dict[image_name], dtype=np.int32) | |||||
row = {"image": img, "annotation": annos} | |||||
writer.write_raw_data([row]) | |||||
writer.commit() | |||||
def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, device_num=1, rank_id=0, | |||||
is_training=True, num_parallel_workers=4): | |||||
"""Creatr FasterRcnn dataset with MindDataset.""" | |||||
ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id, | |||||
num_parallel_workers=1, shuffle=False) | |||||
decode = C.Decode() | |||||
ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=1) | |||||
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training)) | |||||
hwc_to_chw = C.HWC2CHW() | |||||
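# Per-channel normalization with the standard ImageNet mean/std values (RGB order).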
normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375)) | |||||
horizontally_op = C.RandomHorizontalFlip(1) | |||||
type_cast0 = CC.TypeCast(mstype.float32) | |||||
type_cast1 = CC.TypeCast(mstype.float16) | |||||
type_cast2 = CC.TypeCast(mstype.int32) | |||||
type_cast3 = CC.TypeCast(mstype.bool_) | |||||
if is_training: | |||||
ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"], | |||||
output_columns=["image", "image_shape", "box", "label", "valid_num"], | |||||
column_order=["image", "image_shape", "box", "label", "valid_num"], | |||||
num_parallel_workers=num_parallel_workers) | |||||
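# The flip decision below is drawn once when the dataset is built, so every sample in this pipeline shares the same flip state.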
flip = (np.random.rand() < config.flip_ratio) | |||||
if flip: | |||||
ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"], | |||||
num_parallel_workers=12) | |||||
ds = ds.map(operations=flipped_generation, | |||||
input_columns=["image", "image_shape", "box", "label", "valid_num"], | |||||
num_parallel_workers=num_parallel_workers) | |||||
else: | |||||
ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"], | |||||
num_parallel_workers=12) | |||||
ds = ds.map(operations=[hwc_to_chw, type_cast1], input_columns=["image"], | |||||
num_parallel_workers=12) | |||||
else: | |||||
ds = ds.map(operations=compose_map_func, | |||||
input_columns=["image", "annotation"], | |||||
output_columns=["image", "image_shape", "box", "label", "valid_num"], | |||||
column_order=["image", "image_shape", "box", "label", "valid_num"], | |||||
num_parallel_workers=num_parallel_workers) | |||||
ds = ds.map(operations=[normalize_op, hwc_to_chw, type_cast1], input_columns=["image"], | |||||
num_parallel_workers=24) | |||||
# transpose_column from python to c | |||||
ds = ds.map(operations=[type_cast1], input_columns=["image_shape"]) | |||||
ds = ds.map(operations=[type_cast1], input_columns=["box"]) | |||||
ds = ds.map(operations=[type_cast2], input_columns=["label"]) | |||||
ds = ds.map(operations=[type_cast3], input_columns=["valid_num"]) | |||||
ds = ds.batch(batch_size, drop_remainder=True) | |||||
ds = ds.repeat(repeat_num) | |||||
return ds |
@@ -0,0 +1,42 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""lr generator for fasterrcnn""" | |||||
import math | |||||
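# The helpers below build a per-step schedule: linear warmup from init_lr to base_lr over
# warmup_steps, followed by a half-cosine decay of base_lr towards zero over the remaining steps.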
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr): | |||||
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) | |||||
learning_rate = float(init_lr) + lr_inc * current_step | |||||
return learning_rate | |||||
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps): | |||||
base = float(current_step - warmup_steps) / float(decay_steps) | |||||
learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr | |||||
return learning_rate | |||||
def dynamic_lr(config, rank_size=1): | |||||
"""dynamic learning rate generator""" | |||||
base_lr = config.base_lr | |||||
base_step = (config.base_step // rank_size) + rank_size | |||||
total_steps = int(base_step * config.total_epoch) | |||||
warmup_steps = int(config.warmup_step) | |||||
lr = [] | |||||
for i in range(total_steps): | |||||
if i < warmup_steps: | |||||
lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio)) | |||||
else: | |||||
lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps)) | |||||
return lr |
@@ -0,0 +1,184 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""FasterRcnn training network wrapper.""" | |||||
import time | |||||
import numpy as np | |||||
import mindspore.nn as nn | |||||
from mindspore.common.tensor import Tensor | |||||
from mindspore.ops import functional as F | |||||
from mindspore.ops import composite as C | |||||
from mindspore import ParameterTuple | |||||
from mindspore.train.callback import Callback | |||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer | |||||
# pylint: disable=locally-disabled, missing-docstring, unused-argument | |||||
time_stamp_init = False | |||||
time_stamp_first = 0 | |||||
class LossCallBack(Callback): | |||||
""" | |||||
Monitor the loss in training. | |||||
If the loss is NAN or INF, training is terminated.
Note:
If per_print_times is 0, the loss is not printed.
Args:
per_print_times (int): Print the loss once every `per_print_times` steps. Default: 1.
""" | |||||
def __init__(self, per_print_times=1, rank_id=0): | |||||
super(LossCallBack, self).__init__() | |||||
if not isinstance(per_print_times, int) or per_print_times < 0: | |||||
raise ValueError("print_step must be int and >= 0.") | |||||
self._per_print_times = per_print_times | |||||
self.count = 0 | |||||
self.rpn_loss_sum = 0 | |||||
self.rcnn_loss_sum = 0 | |||||
self.rpn_cls_loss_sum = 0 | |||||
self.rpn_reg_loss_sum = 0 | |||||
self.rcnn_cls_loss_sum = 0 | |||||
self.rcnn_reg_loss_sum = 0 | |||||
self.rank_id = rank_id | |||||
global time_stamp_init, time_stamp_first | |||||
if not time_stamp_init: | |||||
time_stamp_first = time.time() | |||||
time_stamp_init = True | |||||
def step_end(self, run_context): | |||||
cb_params = run_context.original_args() | |||||
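# net_outputs holds the six loss terms returned by the training cell, read here in the order (rpn, rcnn, rpn_cls, rpn_reg, rcnn_cls, rcnn_reg).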
rpn_loss = cb_params.net_outputs[0].asnumpy() | |||||
rcnn_loss = cb_params.net_outputs[1].asnumpy() | |||||
rpn_cls_loss = cb_params.net_outputs[2].asnumpy() | |||||
rpn_reg_loss = cb_params.net_outputs[3].asnumpy() | |||||
rcnn_cls_loss = cb_params.net_outputs[4].asnumpy() | |||||
rcnn_reg_loss = cb_params.net_outputs[5].asnumpy() | |||||
self.count += 1 | |||||
self.rpn_loss_sum += float(rpn_loss) | |||||
self.rcnn_loss_sum += float(rcnn_loss) | |||||
self.rpn_cls_loss_sum += float(rpn_cls_loss) | |||||
self.rpn_reg_loss_sum += float(rpn_reg_loss) | |||||
self.rcnn_cls_loss_sum += float(rcnn_cls_loss) | |||||
self.rcnn_reg_loss_sum += float(rcnn_reg_loss) | |||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 | |||||
if self.count >= 1: | |||||
global time_stamp_first | |||||
time_stamp_current = time.time() | |||||
rpn_loss = self.rpn_loss_sum/self.count | |||||
rcnn_loss = self.rcnn_loss_sum/self.count | |||||
rpn_cls_loss = self.rpn_cls_loss_sum/self.count | |||||
rpn_reg_loss = self.rpn_reg_loss_sum/self.count | |||||
rcnn_cls_loss = self.rcnn_cls_loss_sum/self.count | |||||
rcnn_reg_loss = self.rcnn_reg_loss_sum/self.count | |||||
total_loss = rpn_loss + rcnn_loss | |||||
loss_file = open("./loss_{}.log".format(self.rank_id), "a+") | |||||
loss_file.write("%lu epoch: %s step: %s ,rpn_loss: %.5f, rcnn_loss: %.5f, rpn_cls_loss: %.5f, " | |||||
"rpn_reg_loss: %.5f, rcnn_cls_loss: %.5f, rcnn_reg_loss: %.5f, total_loss: %.5f" % | |||||
(time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch, | |||||
rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss, | |||||
rcnn_cls_loss, rcnn_reg_loss, total_loss)) | |||||
loss_file.write("\n") | |||||
loss_file.close() | |||||
self.count = 0 | |||||
self.rpn_loss_sum = 0 | |||||
self.rcnn_loss_sum = 0 | |||||
self.rpn_cls_loss_sum = 0 | |||||
self.rpn_reg_loss_sum = 0 | |||||
self.rcnn_cls_loss_sum = 0 | |||||
self.rcnn_reg_loss_sum = 0 | |||||
class LossNet(nn.Cell): | |||||
"""FasterRcnn loss method""" | |||||
def construct(self, x1, x2, x3, x4, x5, x6): | |||||
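# x1 and x2 are the RPN and RCNN losses; the remaining inputs are component losses kept only for logging, so the optimized loss is their sum.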
return x1 + x2 | |||||
class WithLossCell(nn.Cell): | |||||
""" | |||||
Wrap the network with loss function to compute loss. | |||||
Args: | |||||
backbone (Cell): The target network to wrap. | |||||
loss_fn (Cell): The loss function used to compute loss. | |||||
""" | |||||
def __init__(self, backbone, loss_fn): | |||||
super(WithLossCell, self).__init__(auto_prefix=False) | |||||
self._backbone = backbone | |||||
self._loss_fn = loss_fn | |||||
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num): | |||||
loss1, loss2, loss3, loss4, loss5, loss6 = self._backbone(x, img_shape, gt_bboxe, gt_label, gt_num) | |||||
return self._loss_fn(loss1, loss2, loss3, loss4, loss5, loss6) | |||||
@property | |||||
def backbone_network(self): | |||||
""" | |||||
Get the backbone network. | |||||
Returns: | |||||
Cell, return backbone network. | |||||
""" | |||||
return self._backbone | |||||
class TrainOneStepCell(nn.Cell): | |||||
""" | |||||
Network training package class. | |||||
Append an optimizer to the training network. After that, the construct function
can be called to create the backward graph.
Args: | |||||
network (Cell): The training network. | |||||
network_backbone (Cell): The forward network. | |||||
optimizer (Cell): Optimizer for updating the weights. | |||||
sens (Number): The adjust parameter. Default value is 1.0. | |||||
reduce_flag (bool): The reduce flag. Default value is False. | |||||
mean (bool): Allreduce method. Default value is False. | |||||
degree (int): Device number. Default value is None. | |||||
""" | |||||
def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None): | |||||
super(TrainOneStepCell, self).__init__(auto_prefix=False) | |||||
self.network = network | |||||
self.network.set_grad() | |||||
self.backbone = network_backbone | |||||
self.weights = ParameterTuple(network.trainable_params()) | |||||
self.optimizer = optimizer | |||||
self.grad = C.GradOperation(get_by_list=True, | |||||
sens_param=True) | |||||
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16)) | |||||
self.reduce_flag = reduce_flag | |||||
if reduce_flag: | |||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||||
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num): | |||||
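# Run the backbone once to obtain the individual loss terms for logging, then differentiate the
# wrapped loss network with sensitivity self.sens and apply the optimizer update; F.depend ties
# the update to loss1 so it is not pruned from the graph.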
weights = self.weights | |||||
loss1, loss2, loss3, loss4, loss5, loss6 = self.backbone(x, img_shape, gt_bboxe, gt_label, gt_num) | |||||
grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, self.sens) | |||||
if self.reduce_flag: | |||||
grads = self.grad_reducer(grads) | |||||
return F.depend(loss1, self.optimizer(grads)), loss2, loss3, loss4, loss5, loss6 |
@@ -0,0 +1,227 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""coco eval for fasterrcnn""" | |||||
import json | |||||
import numpy as np | |||||
from pycocotools.coco import COCO | |||||
from pycocotools.cocoeval import COCOeval | |||||
import mmcv | |||||
# pylint: disable=locally-disabled, invalid-name | |||||
_init_value = np.array(0.0) | |||||
summary_init = { | |||||
'Precision/mAP': _init_value, | |||||
'Precision/mAP@.50IOU': _init_value, | |||||
'Precision/mAP@.75IOU': _init_value, | |||||
'Precision/mAP (small)': _init_value, | |||||
'Precision/mAP (medium)': _init_value, | |||||
'Precision/mAP (large)': _init_value, | |||||
'Recall/AR@1': _init_value, | |||||
'Recall/AR@10': _init_value, | |||||
'Recall/AR@100': _init_value, | |||||
'Recall/AR@100 (small)': _init_value, | |||||
'Recall/AR@100 (medium)': _init_value, | |||||
'Recall/AR@100 (large)': _init_value, | |||||
} | |||||
def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000), single_result=False): | |||||
"""coco eval for fasterrcnn""" | |||||
anns = json.load(open(result_files['bbox'])) | |||||
if not anns: | |||||
return summary_init | |||||
if mmcv.is_str(coco): | |||||
coco = COCO(coco) | |||||
assert isinstance(coco, COCO) | |||||
for res_type in result_types: | |||||
result_file = result_files[res_type] | |||||
assert result_file.endswith('.json') | |||||
coco_dets = coco.loadRes(result_file) | |||||
gt_img_ids = coco.getImgIds() | |||||
det_img_ids = coco_dets.getImgIds() | |||||
iou_type = 'bbox' if res_type == 'proposal' else res_type | |||||
cocoEval = COCOeval(coco, coco_dets, iou_type) | |||||
if res_type == 'proposal': | |||||
cocoEval.params.useCats = 0 | |||||
cocoEval.params.maxDets = list(max_dets) | |||||
tgt_ids = gt_img_ids if not single_result else det_img_ids | |||||
if single_result: | |||||
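# Single-result mode: additionally evaluate every image on its own and record its AP@.50IOU keyed by file name.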
res_dict = dict() | |||||
for id_i in tgt_ids: | |||||
cocoEval = COCOeval(coco, coco_dets, iou_type) | |||||
if res_type == 'proposal': | |||||
cocoEval.params.useCats = 0 | |||||
cocoEval.params.maxDets = list(max_dets) | |||||
cocoEval.params.imgIds = [id_i] | |||||
cocoEval.evaluate() | |||||
cocoEval.accumulate() | |||||
cocoEval.summarize() | |||||
res_dict.update({coco.imgs[id_i]['file_name']: cocoEval.stats[1]}) | |||||
cocoEval = COCOeval(coco, coco_dets, iou_type) | |||||
if res_type == 'proposal': | |||||
cocoEval.params.useCats = 0 | |||||
cocoEval.params.maxDets = list(max_dets) | |||||
cocoEval.params.imgIds = tgt_ids | |||||
cocoEval.evaluate() | |||||
cocoEval.accumulate() | |||||
cocoEval.summarize() | |||||
summary_metrics = { | |||||
'Precision/mAP': cocoEval.stats[0], | |||||
'Precision/mAP@.50IOU': cocoEval.stats[1], | |||||
'Precision/mAP@.75IOU': cocoEval.stats[2], | |||||
'Precision/mAP (small)': cocoEval.stats[3], | |||||
'Precision/mAP (medium)': cocoEval.stats[4], | |||||
'Precision/mAP (large)': cocoEval.stats[5], | |||||
'Recall/AR@1': cocoEval.stats[6], | |||||
'Recall/AR@10': cocoEval.stats[7], | |||||
'Recall/AR@100': cocoEval.stats[8], | |||||
'Recall/AR@100 (small)': cocoEval.stats[9], | |||||
'Recall/AR@100 (medium)': cocoEval.stats[10], | |||||
'Recall/AR@100 (large)': cocoEval.stats[11], | |||||
} | |||||
return summary_metrics | |||||
def xyxy2xywh(bbox): | |||||
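# COCO results store boxes as [x, y, width, height]; convert from [x1, y1, x2, y2] corners (the +1 keeps the legacy inclusive-pixel width/height).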
_bbox = bbox.tolist() | |||||
return [ | |||||
_bbox[0], | |||||
_bbox[1], | |||||
_bbox[2] - _bbox[0] + 1, | |||||
_bbox[3] - _bbox[1] + 1, | |||||
] | |||||
def bbox2result_1image(bboxes, labels, num_classes): | |||||
"""Convert detection results to a list of numpy arrays. | |||||
Args: | |||||
bboxes (Tensor): shape (n, 5) | |||||
labels (Tensor): shape (n, ) | |||||
num_classes (int): class number, including background class | |||||
Returns: | |||||
list(ndarray): bbox results of each class | |||||
""" | |||||
if bboxes.shape[0] == 0: | |||||
result = [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes - 1)] | |||||
else: | |||||
result = [bboxes[labels == i, :] for i in range(num_classes - 1)] | |||||
return result | |||||
def proposal2json(dataset, results): | |||||
"""convert proposal to json mode""" | |||||
img_ids = dataset.getImgIds() | |||||
json_results = [] | |||||
dataset_len = dataset.get_dataset_size()*2 | |||||
for idx in range(dataset_len): | |||||
img_id = img_ids[idx] | |||||
bboxes = results[idx] | |||||
for i in range(bboxes.shape[0]): | |||||
data = dict() | |||||
data['image_id'] = img_id | |||||
data['bbox'] = xyxy2xywh(bboxes[i]) | |||||
data['score'] = float(bboxes[i][4]) | |||||
data['category_id'] = 1 | |||||
json_results.append(data) | |||||
return json_results | |||||
def det2json(dataset, results): | |||||
"""convert det to json mode""" | |||||
cat_ids = dataset.getCatIds() | |||||
img_ids = dataset.getImgIds() | |||||
json_results = [] | |||||
dataset_len = len(img_ids) | |||||
for idx in range(dataset_len): | |||||
img_id = img_ids[idx] | |||||
if idx == len(results): break | |||||
result = results[idx] | |||||
for label, result_label in enumerate(result): | |||||
bboxes = result_label | |||||
for i in range(bboxes.shape[0]): | |||||
data = dict() | |||||
data['image_id'] = img_id | |||||
data['bbox'] = xyxy2xywh(bboxes[i]) | |||||
data['score'] = float(bboxes[i][4]) | |||||
data['category_id'] = cat_ids[label] | |||||
json_results.append(data) | |||||
return json_results | |||||
def segm2json(dataset, results): | |||||
"""convert segm to json mode""" | |||||
bbox_json_results = [] | |||||
segm_json_results = [] | |||||
for idx in range(len(dataset)): | |||||
img_id = dataset.img_ids[idx] | |||||
det, seg = results[idx] | |||||
for label, det_label in enumerate(det): | |||||
# bbox results | |||||
bboxes = det_label | |||||
for i in range(bboxes.shape[0]): | |||||
data = dict() | |||||
data['image_id'] = img_id | |||||
data['bbox'] = xyxy2xywh(bboxes[i]) | |||||
data['score'] = float(bboxes[i][4]) | |||||
data['category_id'] = dataset.cat_ids[label] | |||||
bbox_json_results.append(data) | |||||
if len(seg) == 2: | |||||
segms = seg[0][label] | |||||
mask_score = seg[1][label] | |||||
else: | |||||
segms = seg[label] | |||||
mask_score = [bbox[4] for bbox in bboxes] | |||||
for i in range(bboxes.shape[0]): | |||||
data = dict() | |||||
data['image_id'] = img_id | |||||
data['score'] = float(mask_score[i]) | |||||
data['category_id'] = dataset.cat_ids[label] | |||||
segms[i]['counts'] = segms[i]['counts'].decode() | |||||
data['segmentation'] = segms[i] | |||||
segm_json_results.append(data) | |||||
return bbox_json_results, segm_json_results | |||||
def results2json(dataset, results, out_file): | |||||
"""convert result convert to json mode""" | |||||
result_files = dict() | |||||
if isinstance(results[0], list): | |||||
json_results = det2json(dataset, results) | |||||
result_files['bbox'] = '{}.{}.json'.format(out_file, 'bbox') | |||||
result_files['proposal'] = '{}.{}.json'.format(out_file, 'bbox') | |||||
mmcv.dump(json_results, result_files['bbox']) | |||||
elif isinstance(results[0], tuple): | |||||
json_results = segm2json(dataset, results) | |||||
result_files['bbox'] = '{}.{}.json'.format(out_file, 'bbox') | |||||
result_files['proposal'] = '{}.{}.json'.format(out_file, 'bbox') | |||||
result_files['segm'] = '{}.{}.json'.format(out_file, 'segm') | |||||
mmcv.dump(json_results[0], result_files['bbox']) | |||||
mmcv.dump(json_results[1], result_files['segm']) | |||||
elif isinstance(results[0], np.ndarray): | |||||
json_results = proposal2json(dataset, results) | |||||
result_files['proposal'] = '{}.{}.json'.format(out_file, 'proposal') | |||||
mmcv.dump(json_results, result_files['proposal']) | |||||
else: | |||||
raise TypeError('invalid type of results') | |||||
return result_files |
@@ -19,7 +19,7 @@ from abc import abstractmethod | |||||
import numpy as np | import numpy as np | ||||
from mindspore import Tensor | from mindspore import Tensor | ||||
from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | |||||
from mindspore.nn import Cell | |||||
from mindarmour.utils.util import WithLossCell, GradWrapWithLoss | from mindarmour.utils.util import WithLossCell, GradWrapWithLoss | ||||
from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
@@ -44,12 +44,13 @@ class GradientMethod(Attack): | |||||
Default: None. | Default: None. | ||||
bounds (tuple): Upper and lower bounds of data, indicating the data range. | bounds (tuple): Upper and lower bounds of data, indicating the data range. | ||||
In form of (clip_min, clip_max). Default: None. | In form of (clip_min, clip_max). Default: None. | ||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||||
is already equipped with loss function. Default: None. | |||||
Examples: | Examples: | ||||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | ||||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | ||||
>>> attack = FastGradientMethod(network) | |||||
>>> attack = FastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> adv_x = attack.generate(inputs, labels) | >>> adv_x = attack.generate(inputs, labels) | ||||
""" | """ | ||||
@@ -71,9 +72,10 @@ class GradientMethod(Attack): | |||||
else: | else: | ||||
self._alpha = alpha | self._alpha = alpha | ||||
if loss_fn is None: | if loss_fn is None: | ||||
loss_fn = SoftmaxCrossEntropyWithLogits(sparse=False) | |||||
with_loss_cell = WithLossCell(self._network, loss_fn) | |||||
self._grad_all = GradWrapWithLoss(with_loss_cell) | |||||
self._grad_all = self._network | |||||
else: | |||||
with_loss_cell = WithLossCell(self._network, loss_fn) | |||||
self._grad_all = GradWrapWithLoss(with_loss_cell) | |||||
self._grad_all.set_train() | self._grad_all.set_train() | ||||
def generate(self, inputs, labels): | def generate(self, inputs, labels): | ||||
@@ -83,13 +85,19 @@ class GradientMethod(Attack): | |||||
Args: | Args: | ||||
inputs (numpy.ndarray): Benign input samples used as references to create | inputs (numpy.ndarray): Benign input samples used as references to create | ||||
adversarial examples. | adversarial examples. | ||||
labels (numpy.ndarray): Original/target labels. | |||||
labels (Union[numpy.ndarray, tuple]): Original/target labels. \ | |||||
For each input if it has more than one label, it is wrapped in a tuple. | |||||
Returns: | Returns: | ||||
numpy.ndarray, generated adversarial examples. | numpy.ndarray, generated adversarial examples. | ||||
""" | """ | ||||
inputs, labels = check_pair_numpy_param('inputs', inputs, | |||||
'labels', labels) | |||||
if isinstance(labels, tuple): | |||||
for i, labels_item in enumerate(labels): | |||||
inputs, _ = check_pair_numpy_param('inputs', inputs, \ | |||||
'labels[{}]'.format(i), labels_item) | |||||
else: | |||||
inputs, _ = check_pair_numpy_param('inputs', inputs, \ | |||||
'labels', labels) | |||||
self._dtype = inputs.dtype | self._dtype = inputs.dtype | ||||
gradient = self._gradient(inputs, labels) | gradient = self._gradient(inputs, labels) | ||||
# use random method or not | # use random method or not | ||||
@@ -117,7 +125,8 @@ class GradientMethod(Attack): | |||||
Args: | Args: | ||||
inputs (numpy.ndarray): Benign input samples used as references to | inputs (numpy.ndarray): Benign input samples used as references to | ||||
create adversarial examples. | create adversarial examples. | ||||
labels (numpy.ndarray): Original/target labels. | |||||
labels (Union[numpy.ndarray, tuple]): Original/target labels. \ | |||||
For each input if it has more than one label, it is wrapped in a tuple. | |||||
Raises: | Raises: | ||||
NotImplementedError: It is an abstract method. | NotImplementedError: It is an abstract method. | ||||
@@ -149,12 +158,13 @@ class FastGradientMethod(GradientMethod): | |||||
Possible values: np.inf, 1 or 2. Default: 2. | Possible values: np.inf, 1 or 2. Default: 2. | ||||
is_targeted (bool): If True, targeted attack. If False, untargeted | is_targeted (bool): If True, targeted attack. If False, untargeted | ||||
attack. Default: False. | attack. Default: False. | ||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||||
is already equipped with loss function. Default: None. | |||||
Examples: | Examples: | ||||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | ||||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | ||||
>>> attack = FastGradientMethod(network) | |||||
>>> attack = FastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> adv_x = attack.generate(inputs, labels) | >>> adv_x = attack.generate(inputs, labels) | ||||
""" | """ | ||||
@@ -175,12 +185,19 @@ class FastGradientMethod(GradientMethod): | |||||
Args: | Args: | ||||
inputs (numpy.ndarray): Input sample. | inputs (numpy.ndarray): Input sample. | ||||
labels (numpy.ndarray): Original/target label. | |||||
labels (Union[numpy.ndarray, tuple]): Original/target labels. \ | |||||
For each input if it has more than one label, it is wrapped in a tuple. | |||||
Returns: | Returns: | ||||
numpy.ndarray, gradient of inputs. | numpy.ndarray, gradient of inputs. | ||||
""" | """ | ||||
out_grad = self._grad_all(Tensor(inputs), Tensor(labels)) | |||||
if isinstance(labels, tuple): | |||||
labels_tensor = tuple() | |||||
for item in labels: | |||||
labels_tensor += (Tensor(item),) | |||||
else: | |||||
labels_tensor = (Tensor(labels),) | |||||
out_grad = self._grad_all(Tensor(inputs), *labels_tensor) | |||||
if isinstance(out_grad, tuple): | if isinstance(out_grad, tuple): | ||||
out_grad = out_grad[0] | out_grad = out_grad[0] | ||||
gradient = out_grad.asnumpy() | gradient = out_grad.asnumpy() | ||||
@@ -210,7 +227,8 @@ class RandomFastGradientMethod(FastGradientMethod): | |||||
Possible values: np.inf, 1 or 2. Default: 2. | Possible values: np.inf, 1 or 2. Default: 2. | ||||
is_targeted (bool): If True, targeted attack. If False, untargeted | is_targeted (bool): If True, targeted attack. If False, untargeted | ||||
attack. Default: False. | attack. Default: False. | ||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||||
is already equipped with loss function. Default: None. | |||||
Raises: | Raises: | ||||
ValueError: eps is smaller than alpha! | ValueError: eps is smaller than alpha! | ||||
@@ -218,7 +236,7 @@ class RandomFastGradientMethod(FastGradientMethod): | |||||
Examples: | Examples: | ||||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | ||||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | ||||
>>> attack = RandomFastGradientMethod(network) | |||||
>>> attack = RandomFastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> adv_x = attack.generate(inputs, labels) | >>> adv_x = attack.generate(inputs, labels) | ||||
""" | """ | ||||
@@ -254,12 +272,13 @@ class FastGradientSignMethod(GradientMethod): | |||||
In form of (clip_min, clip_max). Default: (0.0, 1.0). | In form of (clip_min, clip_max). Default: (0.0, 1.0). | ||||
is_targeted (bool): If True, targeted attack. If False, untargeted | is_targeted (bool): If True, targeted attack. If False, untargeted | ||||
attack. Default: False. | attack. Default: False. | ||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||||
is already equipped with loss function. Default: None. | |||||
Examples: | Examples: | ||||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | ||||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | ||||
>>> attack = FastGradientSignMethod(network) | |||||
>>> attack = FastGradientSignMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> adv_x = attack.generate(inputs, labels) | >>> adv_x = attack.generate(inputs, labels) | ||||
""" | """ | ||||
@@ -279,12 +298,19 @@ class FastGradientSignMethod(GradientMethod): | |||||
Args: | Args: | ||||
inputs (numpy.ndarray): Input samples. | inputs (numpy.ndarray): Input samples. | ||||
labels (numpy.ndarray): Original/target labels. | |||||
labels (Union[numpy.ndarray, tuple]): Original/target labels. \
    For each input if it has more than one label, it is wrapped in a tuple.
Returns: | Returns: | ||||
numpy.ndarray, gradient of inputs. | numpy.ndarray, gradient of inputs. | ||||
""" | """ | ||||
out_grad = self._grad_all(Tensor(inputs), Tensor(labels)) | |||||
if isinstance(labels, tuple): | |||||
labels_tensor = tuple() | |||||
for item in labels: | |||||
labels_tensor += (Tensor(item),) | |||||
else: | |||||
labels_tensor = (Tensor(labels),) | |||||
out_grad = self._grad_all(Tensor(inputs), *labels_tensor) | |||||
if isinstance(out_grad, tuple): | if isinstance(out_grad, tuple): | ||||
out_grad = out_grad[0] | out_grad = out_grad[0] | ||||
gradient = out_grad.asnumpy() | gradient = out_grad.asnumpy() | ||||
@@ -311,7 +337,8 @@ class RandomFastGradientSignMethod(FastGradientSignMethod): | |||||
In form of (clip_min, clip_max). Default: (0.0, 1.0). | In form of (clip_min, clip_max). Default: (0.0, 1.0). | ||||
is_targeted (bool): True: targeted attack. False: untargeted attack. | is_targeted (bool): True: targeted attack. False: untargeted attack. | ||||
Default: False. | Default: False. | ||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||||
is already equipped with loss function. Default: None. | |||||
Raises: | Raises: | ||||
ValueError: eps is smaller than alpha! | ValueError: eps is smaller than alpha! | ||||
@@ -319,7 +346,7 @@ class RandomFastGradientSignMethod(FastGradientSignMethod): | |||||
Examples: | Examples: | ||||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | ||||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | ||||
>>> attack = RandomFastGradientSignMethod(network) | |||||
>>> attack = RandomFastGradientSignMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> adv_x = attack.generate(inputs, labels) | >>> adv_x = attack.generate(inputs, labels) | ||||
""" | """ | ||||
@@ -350,12 +377,13 @@ class LeastLikelyClassMethod(FastGradientSignMethod): | |||||
Default: None. | Default: None. | ||||
bounds (tuple): Upper and lower bounds of data, indicating the data range. | bounds (tuple): Upper and lower bounds of data, indicating the data range. | ||||
In form of (clip_min, clip_max). Default: (0.0, 1.0). | In form of (clip_min, clip_max). Default: (0.0, 1.0). | ||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||||
is already equipped with loss function. Default: None. | |||||
Examples: | Examples: | ||||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | ||||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | ||||
>>> attack = LeastLikelyClassMethod(network) | |||||
>>> attack = LeastLikelyClassMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> adv_x = attack.generate(inputs, labels) | >>> adv_x = attack.generate(inputs, labels) | ||||
""" | """ | ||||
@@ -384,7 +412,8 @@ class RandomLeastLikelyClassMethod(FastGradientSignMethod): | |||||
Default: 0.035. | Default: 0.035. | ||||
bounds (tuple): Upper and lower bounds of data, indicating the data range. | bounds (tuple): Upper and lower bounds of data, indicating the data range. | ||||
In form of (clip_min, clip_max). Default: (0.0, 1.0). | In form of (clip_min, clip_max). Default: (0.0, 1.0). | ||||
loss_fn (Loss): Loss function for optimization. | |||||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||||
is already equipped with loss function. Default: None. | |||||
Raises: | Raises: | ||||
ValueError: eps is smaller than alpha! | ValueError: eps is smaller than alpha! | ||||
@@ -392,7 +421,7 @@ class RandomLeastLikelyClassMethod(FastGradientSignMethod): | |||||
Examples: | Examples: | ||||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | ||||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | ||||
>>> attack = RandomLeastLikelyClassMethod(network) | |||||
>>> attack = RandomLeastLikelyClassMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> adv_x = attack.generate(inputs, labels) | >>> adv_x = attack.generate(inputs, labels) | ||||
""" | """ | ||||
@@ -17,7 +17,7 @@ from abc import abstractmethod | |||||
import numpy as np | import numpy as np | ||||
from PIL import Image, ImageOps | from PIL import Image, ImageOps | ||||
from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | |||||
from mindspore.nn import Cell | |||||
from mindspore import Tensor | from mindspore import Tensor | ||||
from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
@@ -114,7 +114,8 @@ class IterativeGradientMethod(Attack): | |||||
bounds (tuple): Upper and lower bounds of data, indicating the data range. | bounds (tuple): Upper and lower bounds of data, indicating the data range. | ||||
In form of (clip_min, clip_max). Default: (0.0, 1.0). | In form of (clip_min, clip_max). Default: (0.0, 1.0). | ||||
nb_iter (int): Number of iteration. Default: 5. | nb_iter (int): Number of iteration. Default: 5. | ||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||||
is already equipped with loss function. Default: None. | |||||
""" | """ | ||||
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), nb_iter=5, | def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), nb_iter=5, | ||||
loss_fn=None): | loss_fn=None): | ||||
@@ -123,12 +124,15 @@ class IterativeGradientMethod(Attack): | |||||
self._eps = check_value_positive('eps', eps) | self._eps = check_value_positive('eps', eps) | ||||
self._eps_iter = check_value_positive('eps_iter', eps_iter) | self._eps_iter = check_value_positive('eps_iter', eps_iter) | ||||
self._nb_iter = check_int_positive('nb_iter', nb_iter) | self._nb_iter = check_int_positive('nb_iter', nb_iter) | ||||
self._bounds = check_param_multi_types('bounds', bounds, [list, tuple]) | |||||
for b in self._bounds: | |||||
_ = check_param_multi_types('bound', b, [int, float]) | |||||
self._bounds = None | |||||
if bounds is not None: | |||||
self._bounds = check_param_multi_types('bounds', bounds, [list, tuple]) | |||||
for b in self._bounds: | |||||
_ = check_param_multi_types('bound', b, [int, float]) | |||||
if loss_fn is None: | if loss_fn is None: | ||||
loss_fn = SoftmaxCrossEntropyWithLogits(sparse=False) | |||||
self._loss_grad = GradWrapWithLoss(WithLossCell(self._network, loss_fn)) | |||||
self._loss_grad = network | |||||
else: | |||||
self._loss_grad = GradWrapWithLoss(WithLossCell(self._network, loss_fn)) | |||||
self._loss_grad.set_train() | self._loss_grad.set_train() | ||||
@abstractmethod | @abstractmethod | ||||
@@ -139,8 +143,8 @@ class IterativeGradientMethod(Attack): | |||||
Args: | Args: | ||||
inputs (numpy.ndarray): Benign input samples used as references to create | inputs (numpy.ndarray): Benign input samples used as references to create | ||||
adversarial examples. | adversarial examples. | ||||
labels (numpy.ndarray): Original/target labels. | |||||
labels (Union[numpy.ndarray, tuple]): Original/target labels. \ | |||||
For each input if it has more than one label, it is wrapped in a tuple. | |||||
Raises: | Raises: | ||||
NotImplementedError: This function is not available in | NotImplementedError: This function is not available in | ||||
IterativeGradientMethod. | IterativeGradientMethod. | ||||
@@ -177,12 +181,13 @@ class BasicIterativeMethod(IterativeGradientMethod): | |||||
is_targeted (bool): If True, targeted attack. If False, untargeted | is_targeted (bool): If True, targeted attack. If False, untargeted | ||||
attack. Default: False. | attack. Default: False. | ||||
nb_iter (int): Number of iteration. Default: 5. | nb_iter (int): Number of iteration. Default: 5. | ||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||||
is already equipped with loss function. Default: None. | |||||
attack (class): The single step gradient method of each iteration. In | attack (class): The single step gradient method of each iteration. In | ||||
this class, FGSM is used. | this class, FGSM is used. | ||||
Examples: | Examples: | ||||
>>> attack = BasicIterativeMethod(network) | |||||
>>> attack = BasicIterativeMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
""" | """ | ||||
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | ||||
is_targeted=False, nb_iter=5, loss_fn=None): | is_targeted=False, nb_iter=5, loss_fn=None): | ||||
@@ -207,8 +212,8 @@ class BasicIterativeMethod(IterativeGradientMethod): | |||||
Args: | Args: | ||||
inputs (numpy.ndarray): Benign input samples used as references to | inputs (numpy.ndarray): Benign input samples used as references to | ||||
create adversarial examples. | create adversarial examples. | ||||
labels (numpy.ndarray): Original/target labels. | |||||
labels (Union[numpy.ndarray, tuple]): Original/target labels. \ | |||||
For each input if it has more than one label, it is wrapped in a tuple. | |||||
Returns: | Returns: | ||||
numpy.ndarray, generated adversarial examples. | numpy.ndarray, generated adversarial examples. | ||||
@@ -218,8 +223,13 @@ class BasicIterativeMethod(IterativeGradientMethod): | |||||
>>> [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0], | >>> [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0], | ||||
>>> [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]]) | >>> [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]]) | ||||
""" | """ | ||||
inputs, labels = check_pair_numpy_param('inputs', inputs, | |||||
'labels', labels) | |||||
if isinstance(labels, tuple): | |||||
for i, labels_item in enumerate(labels): | |||||
inputs, _ = check_pair_numpy_param('inputs', inputs, \ | |||||
'labels[{}]'.format(i), labels_item) | |||||
else: | |||||
inputs, _ = check_pair_numpy_param('inputs', inputs, \ | |||||
'labels', labels) | |||||
arr_x = inputs | arr_x = inputs | ||||
if self._bounds is not None: | if self._bounds is not None: | ||||
clip_min, clip_max = self._bounds | clip_min, clip_max = self._bounds | ||||
@@ -267,7 +277,8 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||||
decay_factor (float): Decay factor in iterations. Default: 1.0. | decay_factor (float): Decay factor in iterations. Default: 1.0. | ||||
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | ||||
np.inf, 1 or 2. Default: 'inf'. | np.inf, 1 or 2. Default: 'inf'. | ||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||||
is already equipped with loss function. Default: None. | |||||
""" | """ | ||||
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | ||||
@@ -290,7 +301,8 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||||
Args: | Args: | ||||
inputs (numpy.ndarray): Benign input samples used as references to | inputs (numpy.ndarray): Benign input samples used as references to | ||||
create adversarial examples. | create adversarial examples. | ||||
labels (numpy.ndarray): Original/target labels. | |||||
labels (Union[numpy.ndarray, tuple]): Original/target labels. \ | |||||
For each input if it has more than one label, it is wrapped in a tuple. | |||||
Returns: | Returns: | ||||
numpy.ndarray, generated adversarial examples. | numpy.ndarray, generated adversarial examples. | ||||
@@ -301,8 +313,13 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||||
>>> [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0], | >>> [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0], | ||||
>>> [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]]) | >>> [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]]) | ||||
""" | """ | ||||
inputs, labels = check_pair_numpy_param('inputs', inputs, | |||||
'labels', labels) | |||||
if isinstance(labels, tuple): | |||||
for i, labels_item in enumerate(labels): | |||||
inputs, _ = check_pair_numpy_param('inputs', inputs, \ | |||||
'labels[{}]'.format(i), labels_item) | |||||
else: | |||||
inputs, _ = check_pair_numpy_param('inputs', inputs, \ | |||||
'labels', labels) | |||||
arr_x = inputs | arr_x = inputs | ||||
momentum = 0 | momentum = 0 | ||||
if self._bounds is not None: | if self._bounds is not None: | ||||
@@ -340,7 +357,8 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||||
Args: | Args: | ||||
inputs (numpy.ndarray): Input samples. | inputs (numpy.ndarray): Input samples. | ||||
labels (numpy.ndarray): Original/target labels. | |||||
labels (Union[numpy.ndarray, tuple]): Original/target labels. \ | |||||
For each input if it has more than one label, it is wrapped in a tuple. | |||||
Returns: | Returns: | ||||
numpy.ndarray, gradient of labels w.r.t inputs. | numpy.ndarray, gradient of labels w.r.t inputs. | ||||
@@ -350,7 +368,13 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||||
>>> [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]) | >>> [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]) | ||||
""" | """ | ||||
# get grad of loss over x | # get grad of loss over x | ||||
out_grad = self._loss_grad(Tensor(inputs), Tensor(labels)) | |||||
if isinstance(labels, tuple): | |||||
labels_tensor = tuple() | |||||
for item in labels: | |||||
labels_tensor += (Tensor(item),) | |||||
else: | |||||
labels_tensor = (Tensor(labels),) | |||||
out_grad = self._loss_grad(Tensor(inputs), *labels_tensor) | |||||
if isinstance(out_grad, tuple): | if isinstance(out_grad, tuple): | ||||
out_grad = out_grad[0] | out_grad = out_grad[0] | ||||
gradient = out_grad.asnumpy() | gradient = out_grad.asnumpy() | ||||
@@ -384,7 +408,8 @@ class ProjectedGradientDescent(BasicIterativeMethod): | |||||
nb_iter (int): Number of iteration. Default: 5. | nb_iter (int): Number of iteration. Default: 5. | ||||
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | ||||
np.inf, 1 or 2. Default: 'inf'. | np.inf, 1 or 2. Default: 'inf'. | ||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||||
is already equipped with loss function. Default: None. | |||||
""" | """ | ||||
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | ||||
@@ -406,7 +431,8 @@ class ProjectedGradientDescent(BasicIterativeMethod): | |||||
Args: | Args: | ||||
inputs (numpy.ndarray): Benign input samples used as references to | inputs (numpy.ndarray): Benign input samples used as references to | ||||
create adversarial examples. | create adversarial examples. | ||||
labels (numpy.ndarray): Original/target labels. | |||||
labels (Union[numpy.ndarray, tuple]): Original/target labels. \ | |||||
For each input if it has more than one label, it is wrapped in a tuple. | |||||
Returns: | Returns: | ||||
numpy.ndarray, generated adversarial examples. | numpy.ndarray, generated adversarial examples. | ||||
@@ -417,8 +443,13 @@ class ProjectedGradientDescent(BasicIterativeMethod): | |||||
>>> [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1], | >>> [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1], | ||||
>>> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) | >>> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) | ||||
""" | """ | ||||
inputs, labels = check_pair_numpy_param('inputs', inputs, | |||||
'labels', labels) | |||||
if isinstance(labels, tuple): | |||||
for i, labels_item in enumerate(labels): | |||||
inputs, _ = check_pair_numpy_param('inputs', inputs, \ | |||||
'labels[{}]'.format(i), labels_item) | |||||
else: | |||||
inputs, _ = check_pair_numpy_param('inputs', inputs, \ | |||||
'labels', labels) | |||||
arr_x = inputs | arr_x = inputs | ||||
if self._bounds is not None: | if self._bounds is not None: | ||||
clip_min, clip_max = self._bounds | clip_min, clip_max = self._bounds | ||||
@@ -460,7 +491,8 @@ class DiverseInputIterativeMethod(BasicIterativeMethod): | |||||
is_targeted (bool): If True, targeted attack. If False, untargeted | is_targeted (bool): If True, targeted attack. If False, untargeted | ||||
attack. Default: False. | attack. Default: False. | ||||
prob (float): Transformation probability. Default: 0.5. | prob (float): Transformation probability. Default: 0.5. | ||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||||
is already equipped with loss function. Default: None. | |||||
""" | """ | ||||
def __init__(self, network, eps=0.3, bounds=(0.0, 1.0), | def __init__(self, network, eps=0.3, bounds=(0.0, 1.0), | ||||
is_targeted=False, prob=0.5, loss_fn=None): | is_targeted=False, prob=0.5, loss_fn=None): | ||||
@@ -495,7 +527,8 @@ class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod): | |||||
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | ||||
np.inf, 1 or 2. Default: 'l1'. | np.inf, 1 or 2. Default: 'l1'. | ||||
prob (float): Transformation probability. Default: 0.5. | prob (float): Transformation probability. Default: 0.5. | ||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||||
is already equipped with loss function. Default: None. | |||||
""" | """ | ||||
def __init__(self, network, eps=0.3, bounds=(0.0, 1.0), | def __init__(self, network, eps=0.3, bounds=(0.0, 1.0), | ||||
is_targeted=False, norm_level='l1', prob=0.5, loss_fn=None): | is_targeted=False, norm_level='l1', prob=0.5, loss_fn=None): | ||||
@@ -19,6 +19,7 @@ from random import choice | |||||
import numpy as np | import numpy as np | ||||
from mindspore import Model | from mindspore import Model | ||||
from mindspore import Tensor | from mindspore import Tensor | ||||
from mindspore import nn | |||||
from mindarmour.utils._check_param import check_model, check_numpy_param, \ | from mindarmour.utils._check_param import check_model, check_numpy_param, \ | ||||
check_param_multi_types, check_norm_level, check_param_in_range, \ | check_param_multi_types, check_norm_level, check_param_in_range, \ | ||||
@@ -451,6 +452,8 @@ class Fuzzer: | |||||
else: | else: | ||||
network = self._target_model._network | network = self._target_model._network | ||||
loss_fn = self._target_model._loss_fn | loss_fn = self._target_model._loss_fn | ||||
if loss_fn is None: | |||||
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False) | |||||
mutates[method] = self._strategies[method](network, | mutates[method] = self._strategies[method](network, | ||||
loss_fn=loss_fn) | loss_fn=loss_fn) | ||||
return mutates | return mutates | ||||
@@ -18,7 +18,7 @@ import numpy as np | |||||
import pytest | import pytest | ||||
import mindspore.ops.operations as P | import mindspore.ops.operations as P | ||||
from mindspore.nn import Cell | |||||
from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | |||||
 import mindspore.context as context
 from mindarmour.adv_robustness.attacks import FastGradientMethod
@@ -67,7 +67,7 @@ def test_batch_generate_attack():
     label = np.random.randint(0, 10, 128).astype(np.int32)
     label = np.eye(10)[label].astype(np.float32)
-    attack = FastGradientMethod(Net())
+    attack = FastGradientMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
     ms_adv_x = attack.batch_generate(input_np, label, batch_size=32)
     assert np.any(ms_adv_x != input_np), 'Fast gradient method: generate value' \
@@ -71,7 +71,7 @@ def test_fast_gradient_method():
     label = np.asarray([2], np.int32)
     label = np.eye(3)[label].astype(np.float32)
-    attack = FastGradientMethod(Net())
+    attack = FastGradientMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
     ms_adv_x = attack.generate(input_np, label)
     assert np.any(ms_adv_x != input_np), 'Fast gradient method: generate value' \
@@ -91,7 +91,7 @@ def test_fast_gradient_method_gpu():
     label = np.asarray([2], np.int32)
     label = np.eye(3)[label].astype(np.float32)
-    attack = FastGradientMethod(Net())
+    attack = FastGradientMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
     ms_adv_x = attack.generate(input_np, label)
     assert np.any(ms_adv_x != input_np), 'Fast gradient method: generate value' \
@@ -132,7 +132,7 @@ def test_random_fast_gradient_method():
     label = np.asarray([2], np.int32)
     label = np.eye(3)[label].astype(np.float32)
-    attack = RandomFastGradientMethod(Net())
+    attack = RandomFastGradientMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
     ms_adv_x = attack.generate(input_np, label)
     assert np.any(ms_adv_x != input_np), 'Random fast gradient method: ' \
@@ -154,7 +154,7 @@ def test_fast_gradient_sign_method():
     label = np.asarray([2], np.int32)
     label = np.eye(3)[label].astype(np.float32)
-    attack = FastGradientSignMethod(Net())
+    attack = FastGradientSignMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
     ms_adv_x = attack.generate(input_np, label)
     assert np.any(ms_adv_x != input_np), 'Fast gradient sign method: generate' \
@@ -176,7 +176,7 @@ def test_random_fast_gradient_sign_method():
     label = np.asarray([2], np.int32)
     label = np.eye(28)[label].astype(np.float32)
-    attack = RandomFastGradientSignMethod(Net())
+    attack = RandomFastGradientSignMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
     ms_adv_x = attack.generate(input_np, label)
     assert np.any(ms_adv_x != input_np), 'Random fast gradient sign method: ' \
@@ -198,7 +198,7 @@ def test_least_likely_class_method():
     label = np.asarray([2], np.int32)
     label = np.eye(3)[label].astype(np.float32)
-    attack = LeastLikelyClassMethod(Net())
+    attack = LeastLikelyClassMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
     ms_adv_x = attack.generate(input_np, label)
     assert np.any(ms_adv_x != input_np), 'Least likely class method: generate' \
@@ -220,7 +220,8 @@ def test_random_least_likely_class_method():
     label = np.asarray([2], np.int32)
     label = np.eye(3)[label].astype(np.float32)
-    attack = RandomLeastLikelyClassMethod(Net(), eps=0.1, alpha=0.01)
+    attack = RandomLeastLikelyClassMethod(Net(), eps=0.1, alpha=0.01, \
+                                          loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
     ms_adv_x = attack.generate(input_np, label)
     assert np.any(ms_adv_x != input_np), 'Random least likely class method: ' \
@@ -239,5 +240,6 @@ def test_assert_error():
     """
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     with pytest.raises(ValueError) as e:
-        assert RandomLeastLikelyClassMethod(Net(), eps=0.05, alpha=0.21)
+        assert RandomLeastLikelyClassMethod(Net(), eps=0.05, alpha=0.21, \
+                                            loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
     assert str(e.value) == 'eps must be larger than alpha!'
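
The pattern in the hunks above is the same throughout: every one-step gradient attack is now constructed with an explicit `loss_fn` instead of relying on an internal default. Below is a minimal sketch of the new call, assuming a toy softmax classifier in the spirit of the `Net` fixture used by these tests (the model definition here is illustrative, not part of the diff):

```
import numpy as np
from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
from mindspore.ops import operations as P
from mindarmour.adv_robustness.attacks import FastGradientMethod

class Net(Cell):
    """Toy 3-class model standing in for the test fixture."""
    def __init__(self):
        super(Net, self).__init__()
        self._softmax = P.Softmax()

    def construct(self, inputs):
        return self._softmax(inputs)

# The loss is now passed explicitly; sparse=False because the labels are one-hot.
attack = FastGradientMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))

input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32)
label = np.eye(3)[np.asarray([2], np.int32)].astype(np.float32)
adv_x = attack.generate(input_np, label)                            # single batch
# adv_x = attack.batch_generate(input_np, label, batch_size=32)     # larger arrays
```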
@@ -20,6 +20,7 @@ import pytest
 from mindspore.ops import operations as P
 from mindspore.nn import Cell
 from mindspore import context
+from mindspore.nn import SoftmaxCrossEntropyWithLogits
 from mindarmour.adv_robustness.attacks import BasicIterativeMethod
 from mindarmour.adv_robustness.attacks import MomentumIterativeMethod
@@ -70,7 +71,7 @@ def test_basic_iterative_method():
     for i in range(5):
         net = Net()
-        attack = BasicIterativeMethod(net, nb_iter=i + 1)
+        attack = BasicIterativeMethod(net, nb_iter=i + 1, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
         ms_adv_x = attack.generate(input_np, label)
         assert np.any(
             ms_adv_x != input_np), 'Basic iterative method: generate value' \
@@ -91,7 +92,7 @@ def test_momentum_iterative_method():
     label = np.eye(3)[label].astype(np.float32)
     for i in range(5):
-        attack = MomentumIterativeMethod(Net(), nb_iter=i + 1)
+        attack = MomentumIterativeMethod(Net(), nb_iter=i + 1, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
         ms_adv_x = attack.generate(input_np, label)
         assert np.any(ms_adv_x != input_np), 'Momentum iterative method: generate' \
             ' value must not be equal to' \
@@ -112,7 +113,7 @@ def test_projected_gradient_descent_method():
     label = np.eye(3)[label].astype(np.float32)
     for i in range(5):
-        attack = ProjectedGradientDescent(Net(), nb_iter=i + 1)
+        attack = ProjectedGradientDescent(Net(), nb_iter=i + 1, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
         ms_adv_x = attack.generate(input_np, label)
         assert np.any(
@@ -134,7 +135,7 @@ def test_diverse_input_iterative_method():
     label = np.asarray([2], np.int32)
     label = np.eye(3)[label].astype(np.float32)
-    attack = DiverseInputIterativeMethod(Net())
+    attack = DiverseInputIterativeMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
     ms_adv_x = attack.generate(input_np, label)
     assert np.any(ms_adv_x != input_np), 'Diverse input iterative method: generate' \
         ' value must not be equal to' \
@@ -154,7 +155,7 @@ def test_momentum_diverse_input_iterative_method():
     label = np.asarray([2], np.int32)
     label = np.eye(3)[label].astype(np.float32)
-    attack = MomentumDiverseInputIterativeMethod(Net())
+    attack = MomentumDiverseInputIterativeMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
     ms_adv_x = attack.generate(input_np, label)
     assert np.any(ms_adv_x != input_np), 'Momentum diverse input iterative method: ' \
         'generate value must not be equal to' \
@@ -167,10 +168,7 @@ def test_momentum_diverse_input_iterative_method():
 @pytest.mark.env_card
 @pytest.mark.component_mindarmour
 def test_error():
-    with pytest.raises(TypeError):
-        # check_param_multi_types
-        assert IterativeGradientMethod(Net(), bounds=None)
-    attack = IterativeGradientMethod(Net(), bounds=(0.0, 1.0))
+    attack = IterativeGradientMethod(Net(), bounds=(0.0, 1.0), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
     with pytest.raises(NotImplementedError):
         input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32)
         label = np.asarray([2], np.int32)
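
The iterative attacks take the same keyword on top of their own parameters such as `nb_iter`. Continuing the sketch above (same toy `Net`, `input_np` and `label`; only the constructor changes):

```
from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindarmour.adv_robustness.attacks import ProjectedGradientDescent

# nb_iter controls the number of PGD steps; the tests now pass the loss explicitly.
attack = ProjectedGradientDescent(Net(), nb_iter=5,
                                  loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
adv_x = attack.generate(input_np, label)
```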
@@ -59,8 +59,8 @@ def test_ead():
     optimizer = Momentum(net.trainable_params(), 0.001, 0.9)
     net = Net()
-    fgsm = FastGradientSignMethod(net)
-    pgd = ProjectedGradientDescent(net)
+    fgsm = FastGradientSignMethod(net, loss_fn=loss_fn)
+    pgd = ProjectedGradientDescent(net, loss_fn=loss_fn)
     ead = EnsembleAdversarialDefense(net, [fgsm, pgd], loss_fn=loss_fn,
                                      optimizer=optimizer)
     LOGGER.set_level(logging.DEBUG)
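
The defense test follows the same pattern: the attacks handed to `EnsembleAdversarialDefense` are now built with the same loss the defense trains against. A rough sketch under stated assumptions — the small trainable model is hypothetical (the defense needs parameters, unlike the softmax-only toy above), and the `defenses` import path and the `defense()` call are assumed from MindArmour's `adv_robustness` package:

```
from mindspore.nn import Cell, Dense, Momentum, SoftmaxCrossEntropyWithLogits
from mindarmour.adv_robustness.attacks import FastGradientSignMethod, ProjectedGradientDescent
from mindarmour.adv_robustness.defenses import EnsembleAdversarialDefense  # path assumed

class SmallNet(Cell):
    """Hypothetical stand-in with trainable parameters."""
    def __init__(self):
        super(SmallNet, self).__init__()
        self._fc = Dense(3, 3)

    def construct(self, x):
        return self._fc(x)

net = SmallNet()
loss_fn = SoftmaxCrossEntropyWithLogits(sparse=False)
optimizer = Momentum(net.trainable_params(), 0.001, 0.9)

fgsm = FastGradientSignMethod(net, loss_fn=loss_fn)
pgd = ProjectedGradientDescent(net, loss_fn=loss_fn)
ead = EnsembleAdversarialDefense(net, [fgsm, pgd], loss_fn=loss_fn,
                                 optimizer=optimizer)
# ead.defense(inputs, labels) would then run one adversarial-training step (assumed API).
```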
@@ -117,7 +117,7 @@ def test_lenet_mnist_coverage_ascend():
     LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
     # generate adv_data
-    attack = FastGradientSignMethod(net, eps=0.3)
+    attack = FastGradientSignMethod(net, eps=0.3, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False))
     adv_data = attack.batch_generate(test_data, test_labels, batch_size=32)
     model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5)
     LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())