Browse Source

!137 Extend Deepfool to object detection models

From: @liu_luobin
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
971b455b6f
7 changed files with 346 additions and 175 deletions
  1. +2
    -13
      mindarmour/adv_robustness/attacks/attack.py
  2. +217
    -66
      mindarmour/adv_robustness/attacks/deep_fool.py
  3. +7
    -39
      mindarmour/adv_robustness/attacks/gradient_method.py
  4. +7
    -51
      mindarmour/adv_robustness/attacks/iterative_gradient_method.py
  5. +19
    -0
      mindarmour/utils/_check_param.py
  6. +56
    -6
      mindarmour/utils/util.py
  7. +38
    -0
      tests/ut/python/adv_robustness/attacks/test_deep_fool.py

+ 2
- 13
mindarmour/adv_robustness/attacks/attack.py View File

@@ -18,7 +18,7 @@ from abc import abstractmethod


import numpy as np import numpy as np


from mindarmour.utils._check_param import check_pair_numpy_param, \
from mindarmour.utils._check_param import check_inputs_labels, \
check_int_positive, check_equal_shape, check_numpy_param, check_model check_int_positive, check_equal_shape, check_numpy_param, check_model
from mindarmour.utils.util import calculate_iou from mindarmour.utils.util import calculate_iou
from mindarmour.utils.logger import LogUtil from mindarmour.utils.logger import LogUtil
@@ -55,18 +55,7 @@ class Attack:
>>> labels = np.array([3, 0]) >>> labels = np.array([3, 0])
>>> advs = attack.batch_generate(inputs, labels, batch_size=2) >>> advs = attack.batch_generate(inputs, labels, batch_size=2)
""" """
inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs
if isinstance(inputs, tuple):
for i, inputs_item in enumerate(inputs):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'inputs[{}]'.format(i), inputs_item)
if isinstance(labels, tuple):
for i, labels_item in enumerate(labels):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'labels[{}]'.format(i), labels_item)
else:
_ = check_pair_numpy_param('inputs', inputs_image, \
'labels', labels)
inputs_image, inputs, labels = check_inputs_labels(inputs, labels)
arr_x = inputs arr_x = inputs
arr_y = labels arr_y = labels
len_x = inputs_image.shape[0] len_x = inputs_image.shape[0]


+ 217
- 66
mindarmour/adv_robustness/attacks/deep_fool.py View File

@@ -20,16 +20,76 @@ from mindspore import Tensor
from mindspore.nn import Cell from mindspore.nn import Cell


from mindarmour.utils.logger import LogUtil from mindarmour.utils.logger import LogUtil
from mindarmour.utils.util import GradWrap, jacobian_matrix
from mindarmour.utils.util import GradWrap, jacobian_matrix, \
jacobian_matrix_for_detection, calculate_iou, to_tensor_tuple
from mindarmour.utils._check_param import check_pair_numpy_param, check_model, \ from mindarmour.utils._check_param import check_pair_numpy_param, check_model, \
check_value_positive, check_int_positive, check_norm_level, \ check_value_positive, check_int_positive, check_norm_level, \
check_param_multi_types, check_param_type
check_param_multi_types, check_param_type, check_value_non_negative
from .attack import Attack from .attack import Attack


LOGGER = LogUtil.get_instance() LOGGER = LogUtil.get_instance()
TAG = 'DeepFool' TAG = 'DeepFool'




class _GetLogits(Cell):
def __init__(self, network):
super(_GetLogits, self).__init__()
self._network = network

def construct(self, *inputs):
_, pre_logits = self._network(*inputs)
return pre_logits


def _deepfool_detection_scores(inputs, gt_boxes, gt_labels, network):
"""
Evaluate the detection result of inputs, specially for object detection models.

Args:
inputs (numpy.ndarray): Input samples.
gt_boxes (numpy.ndarray): Ground-truth boxes of inputs.
gt_labels (numpy.ndarray): Ground-truth labels of inputs.
model (BlackModel): Target model.

Returns:
- numpy.ndarray, detection scores of inputs.

- numpy.ndarray, the number of objects that are correctly detected.
"""
network = check_param_type('network', network, Cell)
inputs_tensor = to_tensor_tuple(inputs)
box_and_confi, pred_logits = network(*inputs_tensor)
box_and_confi, pred_logits = box_and_confi.asnumpy(), pred_logits.asnumpy()
pred_labels = np.argmax(pred_logits, axis=2)
det_scores = []
correct_labels_num = []
gt_boxes_num = gt_boxes.shape[0]
iou_thres = 0.5
for idx, (boxes, labels) in enumerate(zip(box_and_confi, pred_labels)):
score = 0
box_num = boxes.shape[0]
correct_label_flag = np.zeros(gt_labels.shape)
gt_boxes_idx = gt_boxes[idx]
gt_labels_idx = gt_labels[idx]
for i in range(box_num):
pred_box = boxes[i]
max_iou_confi = 0
for j in range(gt_boxes_num):
iou = calculate_iou(pred_box[:4], gt_boxes_idx[j][:4])
if labels[i] == gt_labels_idx[j] and iou > iou_thres:
max_iou_confi = max(max_iou_confi, pred_box[-1] + iou)
correct_label_flag[j] = 1
score += max_iou_confi
det_scores.append(score)
correct_labels_num.append(np.sum(correct_label_flag))
return np.array(det_scores), np.array(correct_labels_num)


def _is_success(inputs, gt_boxes, gt_labels, network, gt_object_nums, reserve_ratio):
_, correct_nums_adv = _deepfool_detection_scores(inputs, gt_boxes, gt_labels, network)
return np.all(correct_nums_adv <= (gt_object_nums*reserve_ratio).astype(np.int32))


class DeepFool(Attack): class DeepFool(Attack):
""" """
DeepFool is an untargeted & iterative attack achieved by moving the benign DeepFool is an untargeted & iterative attack achieved by moving the benign
@@ -56,8 +116,8 @@ class DeepFool(Attack):
>>> attack = DeepFool(network) >>> attack = DeepFool(network)
""" """


def __init__(self, network, num_classes, max_iters=50, overshoot=0.02,
norm_level=2, bounds=None, sparse=True):
def __init__(self, network, num_classes, model_type='classification',
reserve_ratio=0.3, max_iters=50, overshoot=0.02, norm_level=2, bounds=None, sparse=True):
super(DeepFool, self).__init__() super(DeepFool, self).__init__()
self._network = check_model('network', network, Cell) self._network = check_model('network', network, Cell)
self._network.set_grad(True) self._network.set_grad(True)
@@ -66,18 +126,32 @@ class DeepFool(Attack):
self._norm_level = check_norm_level(norm_level) self._norm_level = check_norm_level(norm_level)
self._num_classes = check_int_positive('num_classes', num_classes) self._num_classes = check_int_positive('num_classes', num_classes)
self._net_grad = GradWrap(self._network) self._net_grad = GradWrap(self._network)
self._bounds = check_param_multi_types('bounds', bounds, [list, tuple])
self._bounds = bounds
if self._bounds is not None:
self._bounds = check_param_multi_types('bounds', bounds, [list, tuple])
for b in self._bounds:
_ = check_param_multi_types('bound', b, [int, float])
self._sparse = check_param_type('sparse', sparse, bool) self._sparse = check_param_type('sparse', sparse, bool)
for b in self._bounds:
_ = check_param_multi_types('bound', b, [int, float])
self._model_type = check_param_type('model_type', model_type, str)
if self._model_type not in ('classification', 'detection'):
msg = "Only 'classification' or 'detection' is supported now, but got {}.".format(self._model_type)
LOGGER.error(TAG, msg)
raise ValueError(msg)
self._reserve_ratio = check_value_non_negative('reserve_ratio', reserve_ratio)
if self._reserve_ratio > 1:
msg = 'reserve_ratio should be less than 1.0, but got {}.'.format(self._reserve_ratio)
LOGGER.error(TAG, msg)
raise ValueError(TAG, msg)


def generate(self, inputs, labels): def generate(self, inputs, labels):
""" """
Generate adversarial examples based on input samples and original labels. Generate adversarial examples based on input samples and original labels.


Args: Args:
inputs (numpy.ndarray): Input samples.
labels (numpy.ndarray): Original labels.
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs can be (inputs1, input2, ...) \
or only one array if model_type='detection'
labels (Union[numpy.ndarray, tuple]): Original labels. The format of labels should be \
(gt_boxes, gt_labels) if model_type='detection'.


Returns: Returns:
numpy.ndarray, adversarial examples. numpy.ndarray, adversarial examples.
@@ -88,67 +162,144 @@ class DeepFool(Attack):
Examples: Examples:
>>> advs = generate([[0.2, 0.3, 0.4], [0.3, 0.4, 0.5]], [1, 2]) >>> advs = generate([[0.2, 0.3, 0.4], [0.3, 0.4, 0.5]], [1, 2])
""" """
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
if not self._sparse:
labels = np.argmax(labels, axis=1)
inputs_dtype = inputs.dtype
iteration = 0
origin_labels = labels
cur_labels = origin_labels.copy()
weight = np.squeeze(np.zeros(inputs.shape[1:]))
r_tot = np.zeros(inputs.shape)
x_origin = inputs
while np.any(cur_labels == origin_labels) and iteration < self._max_iters:
preds = self._network(Tensor(inputs)).asnumpy()
grads = jacobian_matrix(self._net_grad, inputs, self._num_classes)
for idx in range(inputs.shape[0]):
diff_w = np.inf
label = origin_labels[idx]
if cur_labels[idx] != label:
continue
for k in range(self._num_classes):
if k == label:
if self._model_type == 'detection':
images, auxiliary_inputs = inputs[0], inputs[1:]
gt_boxes, gt_labels = labels
_, gt_object_nums = _deepfool_detection_scores(inputs, gt_boxes, gt_labels, self._network)
if not self._sparse:
gt_labels = np.argmax(gt_labels, axis=2)
origin_labels = np.zeros(gt_labels.shape[0])
for i in range(gt_labels.shape[0]):
origin_labels[i] = np.argmax(np.bincount(gt_labels[i]))
images_dtype = images.dtype
iteration = 0
num_boxes = gt_labels.shape[1]
merge_net = _GetLogits(self._network)
detection_net_grad = GradWrap(merge_net)
weight = np.squeeze(np.zeros(images.shape[1:]))
r_tot = np.zeros(images.shape)
x_origin = images
while not _is_success((images,) + auxiliary_inputs, gt_boxes, gt_labels, self._network, gt_object_nums, \
self._reserve_ratio) and iteration < self._max_iters:
preds_logits = merge_net(*to_tensor_tuple(images), *to_tensor_tuple(auxiliary_inputs)).asnumpy()
grads = jacobian_matrix_for_detection(detection_net_grad, (images,) + auxiliary_inputs,
num_boxes, self._num_classes)
for idx in range(images.shape[0]):
diff_w = np.inf
label = int(origin_labels[idx])
auxiliary_input_i = tuple()
for item in auxiliary_inputs:
auxiliary_input_i += (np.expand_dims(item[idx], axis=0),)
gt_boxes_i = np.expand_dims(gt_boxes[idx], axis=0)
gt_labels_i = np.expand_dims(gt_labels[idx], axis=0)
inputs_i = (np.expand_dims(images[idx], axis=0),) + auxiliary_input_i
if _is_success(inputs_i, gt_boxes_i, gt_labels_i,
self._network, gt_object_nums[idx], self._reserve_ratio):
continue
for k in range(self._num_classes):
if k == label:
continue
w_k = grads[k, idx, ...] - grads[label, idx, ...]
f_k = np.mean(np.abs(preds_logits[idx, :, k, ...] - preds_logits[idx, :, label, ...]))
if self._norm_level == 2 or self._norm_level == '2':
diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8)
elif self._norm_level == np.inf \
or self._norm_level == 'inf':
diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8)
else:
msg = 'ord {} is not available.' \
.format(str(self._norm_level))
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
if diff_w_k < diff_w:
diff_w = diff_w_k
weight = w_k
if self._norm_level == 2 or self._norm_level == '2':
r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8)
elif self._norm_level == np.inf or self._norm_level == 'inf':
r_i = diff_w*np.sign(weight) \
/ (np.linalg.norm(weight, ord=1) + 1e-8)
else:
msg = 'ord {} is not available in normalization,' \
.format(str(self._norm_level))
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
r_tot[idx, ...] = r_tot[idx, ...] + r_i

if self._bounds is not None:
clip_min, clip_max = self._bounds
images = x_origin + (1 + self._overshoot)*r_tot*(clip_max-clip_min)
images = np.clip(images, clip_min, clip_max)
else:
images = x_origin + (1 + self._overshoot)*r_tot
iteration += 1
images = images.astype(images_dtype)
del preds_logits, grads
return images

if self._model_type == 'classification':
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
if not self._sparse:
labels = np.argmax(labels, axis=1)
inputs_dtype = inputs.dtype
iteration = 0
origin_labels = labels
cur_labels = origin_labels.copy()
weight = np.squeeze(np.zeros(inputs.shape[1:]))
r_tot = np.zeros(inputs.shape)
x_origin = inputs
while np.any(cur_labels == origin_labels) and iteration < self._max_iters:
preds = self._network(Tensor(inputs)).asnumpy()
grads = jacobian_matrix(self._net_grad, inputs, self._num_classes)
for idx in range(inputs.shape[0]):
diff_w = np.inf
label = origin_labels[idx]
if cur_labels[idx] != label:
continue continue
w_k = grads[k, idx, ...] - grads[label, idx, ...]
f_k = preds[idx, k] - preds[idx, label]
for k in range(self._num_classes):
if k == label:
continue
w_k = grads[k, idx, ...] - grads[label, idx, ...]
f_k = preds[idx, k] - preds[idx, label]
if self._norm_level == 2 or self._norm_level == '2':
diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8)
elif self._norm_level == np.inf \
or self._norm_level == 'inf':
diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8)
else:
msg = 'ord {} is not available.' \
.format(str(self._norm_level))
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
if diff_w_k < diff_w:
diff_w = diff_w_k
weight = w_k

if self._norm_level == 2 or self._norm_level == '2': if self._norm_level == 2 or self._norm_level == '2':
diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8)
elif self._norm_level == np.inf \
or self._norm_level == 'inf':
diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8)
r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8)
elif self._norm_level == np.inf or self._norm_level == 'inf':
r_i = diff_w*np.sign(weight) \
/ (np.linalg.norm(weight, ord=1) + 1e-8)
else: else:
msg = 'ord {} is not available.' \
msg = 'ord {} is not available in normalization.' \
.format(str(self._norm_level)) .format(str(self._norm_level))
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise NotImplementedError(msg) raise NotImplementedError(msg)
if diff_w_k < diff_w:
diff_w = diff_w_k
weight = w_k

if self._norm_level == 2 or self._norm_level == '2':
r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8)
elif self._norm_level == np.inf or self._norm_level == 'inf':
r_i = diff_w*np.sign(weight) \
/ (np.linalg.norm(weight, ord=1) + 1e-8)
r_tot[idx, ...] = r_tot[idx, ...] + r_i

if self._bounds is not None:
clip_min, clip_max = self._bounds
inputs = x_origin + (1 + self._overshoot)*r_tot*(clip_max
- clip_min)
inputs = np.clip(inputs, clip_min, clip_max)
else: else:
msg = 'ord {} is not available in normalization.' \
.format(str(self._norm_level))
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
r_tot[idx, ...] = r_tot[idx, ...] + r_i

if self._bounds is not None:
clip_min, clip_max = self._bounds
inputs = x_origin + (1 + self._overshoot)*r_tot*(clip_max
- clip_min)
inputs = np.clip(inputs, clip_min, clip_max)
else:
inputs = x_origin + (1 + self._overshoot)*r_tot
cur_labels = np.argmax(
self._network(Tensor(inputs.astype(inputs_dtype))).asnumpy(),
axis=1)
iteration += 1
inputs = inputs.astype(inputs_dtype)
del preds, grads
return inputs
inputs = x_origin + (1 + self._overshoot)*r_tot
cur_labels = np.argmax(
self._network(Tensor(inputs.astype(inputs_dtype))).asnumpy(),
axis=1)
iteration += 1
inputs = inputs.astype(inputs_dtype)
del preds, grads
return inputs
return None

+ 7
- 39
mindarmour/adv_robustness/attacks/gradient_method.py View File

@@ -18,12 +18,11 @@ from abc import abstractmethod


import numpy as np import numpy as np


from mindspore import Tensor
from mindspore.nn import Cell from mindspore.nn import Cell


from mindarmour.utils.util import WithLossCell, GradWrapWithLoss
from mindarmour.utils.util import WithLossCell, GradWrapWithLoss, to_tensor_tuple
from mindarmour.utils.logger import LogUtil from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_pair_numpy_param, check_model, \
from mindarmour.utils._check_param import check_model, check_inputs_labels, \
normalize_value, check_value_positive, check_param_multi_types, \ normalize_value, check_value_positive, check_param_multi_types, \
check_norm_level, check_param_type check_norm_level, check_param_type
from .attack import Attack from .attack import Attack
@@ -91,18 +90,7 @@ class GradientMethod(Attack):
Returns: Returns:
numpy.ndarray, generated adversarial examples. numpy.ndarray, generated adversarial examples.
""" """
inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs
if isinstance(inputs, tuple):
for i, inputs_item in enumerate(inputs):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'inputs[{}]'.format(i), inputs_item)
if isinstance(labels, tuple):
for i, labels_item in enumerate(labels):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'labels[{}]'.format(i), labels_item)
else:
_ = check_pair_numpy_param('inputs', inputs_image, \
'labels', labels)
inputs_image, inputs, labels = check_inputs_labels(inputs, labels)
self._dtype = inputs_image.dtype self._dtype = inputs_image.dtype
gradient = self._gradient(inputs, labels) gradient = self._gradient(inputs, labels)
# use random method or not # use random method or not
@@ -196,18 +184,8 @@ class FastGradientMethod(GradientMethod):
Returns: Returns:
numpy.ndarray, gradient of inputs. numpy.ndarray, gradient of inputs.
""" """
if isinstance(inputs, tuple):
inputs_tensor = tuple()
for item in inputs:
inputs_tensor += (Tensor(item),)
else:
inputs_tensor = (Tensor(inputs),)
if isinstance(labels, tuple):
labels_tensor = tuple()
for item in labels:
labels_tensor += (Tensor(item),)
else:
labels_tensor = (Tensor(labels),)
inputs_tensor = to_tensor_tuple(inputs)
labels_tensor = to_tensor_tuple(labels)
out_grad = self._grad_all(*inputs_tensor, *labels_tensor) out_grad = self._grad_all(*inputs_tensor, *labels_tensor)
if isinstance(out_grad, tuple): if isinstance(out_grad, tuple):
out_grad = out_grad[0] out_grad = out_grad[0]
@@ -315,18 +293,8 @@ class FastGradientSignMethod(GradientMethod):
Returns: Returns:
numpy.ndarray, gradient of inputs. numpy.ndarray, gradient of inputs.
""" """
if isinstance(inputs, tuple):
inputs_tensor = tuple()
for item in inputs:
inputs_tensor += (Tensor(item),)
else:
inputs_tensor = (Tensor(inputs),)
if isinstance(labels, tuple):
labels_tensor = tuple()
for item in labels:
labels_tensor += (Tensor(item),)
else:
labels_tensor = (Tensor(labels),)
inputs_tensor = to_tensor_tuple(inputs)
labels_tensor = to_tensor_tuple(labels)
out_grad = self._grad_all(*inputs_tensor, *labels_tensor) out_grad = self._grad_all(*inputs_tensor, *labels_tensor)
if isinstance(out_grad, tuple): if isinstance(out_grad, tuple):
out_grad = out_grad[0] out_grad = out_grad[0]


+ 7
- 51
mindarmour/adv_robustness/attacks/iterative_gradient_method.py View File

@@ -18,11 +18,10 @@ import numpy as np
from PIL import Image, ImageOps from PIL import Image, ImageOps


from mindspore.nn import Cell from mindspore.nn import Cell
from mindspore import Tensor


from mindarmour.utils.logger import LogUtil from mindarmour.utils.logger import LogUtil
from mindarmour.utils.util import WithLossCell, GradWrapWithLoss
from mindarmour.utils._check_param import check_pair_numpy_param, \
from mindarmour.utils.util import WithLossCell, GradWrapWithLoss, to_tensor_tuple
from mindarmour.utils._check_param import check_inputs_labels, \
normalize_value, check_model, check_value_positive, check_int_positive, \ normalize_value, check_model, check_value_positive, check_int_positive, \
check_param_type, check_norm_level, check_param_multi_types check_param_type, check_norm_level, check_param_multi_types
from .attack import Attack from .attack import Attack
@@ -223,18 +222,7 @@ class BasicIterativeMethod(IterativeGradientMethod):
>>> [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0], >>> [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
>>> [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]]) >>> [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]])
""" """
inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs
if isinstance(inputs, tuple):
for i, inputs_item in enumerate(inputs):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'inputs[{}]'.format(i), inputs_item)
if isinstance(labels, tuple):
for i, labels_item in enumerate(labels):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'labels[{}]'.format(i), labels_item)
else:
_ = check_pair_numpy_param('inputs', inputs_image, \
'labels', labels)
inputs_image, inputs, labels = check_inputs_labels(inputs, labels)
arr_x = inputs_image arr_x = inputs_image
if self._bounds is not None: if self._bounds is not None:
clip_min, clip_max = self._bounds clip_min, clip_max = self._bounds
@@ -322,18 +310,7 @@ class MomentumIterativeMethod(IterativeGradientMethod):
>>> [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0], >>> [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
>>> [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]]) >>> [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]])
""" """
inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs
if isinstance(inputs, tuple):
for i, inputs_item in enumerate(inputs):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'inputs[{}]'.format(i), inputs_item)
if isinstance(labels, tuple):
for i, labels_item in enumerate(labels):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'labels[{}]'.format(i), labels_item)
else:
_ = check_pair_numpy_param('inputs', inputs_image, \
'labels', labels)
inputs_image, inputs, labels = check_inputs_labels(inputs, labels)
arr_x = inputs_image arr_x = inputs_image
momentum = 0 momentum = 0
if self._bounds is not None: if self._bounds is not None:
@@ -392,18 +369,8 @@ class MomentumIterativeMethod(IterativeGradientMethod):
>>> [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]) >>> [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0])
""" """
# get grad of loss over x # get grad of loss over x
if isinstance(inputs, tuple):
inputs_tensor = tuple()
for item in inputs:
inputs_tensor += (Tensor(item),)
else:
inputs_tensor = (Tensor(inputs),)
if isinstance(labels, tuple):
labels_tensor = tuple()
for item in labels:
labels_tensor += (Tensor(item),)
else:
labels_tensor = (Tensor(labels),)
inputs_tensor = to_tensor_tuple(inputs)
labels_tensor = to_tensor_tuple(labels)
out_grad = self._loss_grad(*inputs_tensor, *labels_tensor) out_grad = self._loss_grad(*inputs_tensor, *labels_tensor)
if isinstance(out_grad, tuple): if isinstance(out_grad, tuple):
out_grad = out_grad[0] out_grad = out_grad[0]
@@ -473,18 +440,7 @@ class ProjectedGradientDescent(BasicIterativeMethod):
>>> [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1], >>> [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
>>> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) >>> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
""" """
inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs
if isinstance(inputs, tuple):
for i, inputs_item in enumerate(inputs):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'inputs[{}]'.format(i), inputs_item)
if isinstance(labels, tuple):
for i, labels_item in enumerate(labels):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'labels[{}]'.format(i), labels_item)
else:
_ = check_pair_numpy_param('inputs', inputs_image, \
'labels', labels)
inputs_image, inputs, labels = check_inputs_labels(inputs, labels)
arr_x = inputs_image arr_x = inputs_image
if self._bounds is not None: if self._bounds is not None:
clip_min, clip_max = self._bounds clip_min, clip_max = self._bounds


+ 19
- 0
mindarmour/utils/_check_param.py View File

@@ -327,3 +327,22 @@ def check_detection_inputs(inputs, labels):
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise ValueError(msg) raise ValueError(msg)
return images, auxiliary_inputs, gt_boxes, gt_labels return images, auxiliary_inputs, gt_boxes, gt_labels


def check_inputs_labels(inputs, labels):
"""check inputs and labels is valid for white box method."""
_ = check_param_multi_types('inputs', inputs, (tuple, np.ndarray))
_ = check_param_multi_types('labels', labels, (tuple, np.ndarray))
inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs
if isinstance(inputs, tuple):
for i, inputs_item in enumerate(inputs):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'inputs[{}]'.format(i), inputs_item)
if isinstance(labels, tuple):
for i, labels_item in enumerate(labels):
_ = check_pair_numpy_param('inputs', inputs_image, \
'labels[{}]'.format(i), labels_item)
else:
_ = check_pair_numpy_param('inputs', inputs_image, \
'labels', labels)
return inputs_image, inputs, labels

+ 56
- 6
mindarmour/utils/util.py View File

@@ -17,7 +17,7 @@ from mindspore import Tensor
from mindspore.nn import Cell from mindspore.nn import Cell
from mindspore.ops.composite import GradOperation from mindspore.ops.composite import GradOperation


from mindarmour.utils._check_param import check_numpy_param
from mindarmour.utils._check_param import check_numpy_param, check_param_multi_types


from .logger import LogUtil from .logger import LogUtil


@@ -54,6 +54,44 @@ def jacobian_matrix(grad_wrap_net, inputs, num_classes):
return np.asarray(grads_matrix) return np.asarray(grads_matrix)




def jacobian_matrix_for_detection(grad_wrap_net, inputs, num_boxes, num_classes):
"""
Calculate the Jacobian matrix for inputs, specifically for object detection model.

Args:
grad_wrap_net (Cell): A network wrapped by GradWrap.
inputs (numpy.ndarray): Input samples.
num_boxes (int): Number of boxes infered by each image.
num_classes (int): Number of labels of model output.

Returns:
numpy.ndarray, the Jacobian matrix of inputs. (labels, batch_size, ...)

Raises:
ValueError: If grad_wrap_net is not a instance of class `GradWrap`.
"""
if not isinstance(grad_wrap_net, GradWrap):
msg = 'grad_wrap_net be and instance of class `GradWrap`.'
LOGGER.error(TAG, msg)
raise ValueError(msg)
grad_wrap_net.set_train()
grads_matrix = []
inputs_tensor = tuple()
if isinstance(inputs, tuple):
for item in inputs:
inputs_tensor += (Tensor(item),)
else:
inputs_tensor += (Tensor(inputs),)
for idx in range(num_classes):
batch_size = inputs[0].shape[0] if isinstance(inputs, tuple) else inputs.shape[0]
sens = np.zeros((batch_size, num_boxes, num_classes)).astype(np.float32)
sens[:, :, idx] = 1.0

grads = grad_wrap_net(*(inputs_tensor), Tensor(sens))
grads_matrix.append(grads.asnumpy())
return np.asarray(grads_matrix)


class WithLossCell(Cell): class WithLossCell(Cell):
""" """
Wrap the network with loss function. Wrap the network with loss function.
@@ -152,19 +190,19 @@ class GradWrap(Cell):
self.grad = GradOperation(get_all=False, sens_param=True) self.grad = GradOperation(get_all=False, sens_param=True)
self.network = network self.network = network


def construct(self, inputs, weight):
def construct(self, *data):
""" """
Compute jacobian matrix. Compute jacobian matrix.


Args: Args:
inputs (Tensor): Inputs of network.
weight (Tensor): Weight of each gradient, `weight` has the same
shape with labels.
data (Tensor): Data consists of inputs and weight. \
- inputs: Inputs of network. \
- weight: Weight of each gradient, 'weight' has the same shape with labels.


Returns: Returns:
Tensor, Jacobian matrix. Tensor, Jacobian matrix.
""" """
gout = self.grad(self.network)(inputs, weight)
gout = self.grad(self.network)(*data)
return gout return gout




@@ -199,3 +237,15 @@ def calculate_iou(box_i, box_j):
return 0 return 0
inner_area = (inner_right_line - inner_left_line)*(inner_top_line - inner_bottom_line) inner_area = (inner_right_line - inner_left_line)*(inner_top_line - inner_bottom_line)
return inner_area / (s_i + s_j - inner_area) return inner_area / (s_i + s_j - inner_area)


def to_tensor_tuple(inputs_ori):
"""Transfer inputs data into tensor type."""
inputs_ori = check_param_multi_types('inputs_ori', inputs_ori, [np.ndarray, tuple])
if isinstance(inputs_ori, tuple):
inputs_tensor = tuple()
for item in inputs_ori:
inputs_tensor += (Tensor(item),)
else:
inputs_tensor = (Tensor(inputs_ori),)
return inputs_tensor

+ 38
- 0
tests/ut/python/adv_robustness/attacks/test_deep_fool.py View File

@@ -54,6 +54,23 @@ class Net(Cell):
return out return out




class Net2(Cell):
"""
Construct the network of target model, specifically for detection model test case.

Examples:
>>> net = Net2()
"""
def __init__(self):
super(Net2, self).__init__()
self._softmax = P.Softmax()

def construct(self, inputs1, inputs2):
out1 = self._softmax(inputs2)
out2 = self._softmax(inputs1)
return out1, out2


@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@@ -79,6 +96,27 @@ def test_deepfool_attack():
' implementation error, ms_adv_x != expect_value' ' implementation error, ms_adv_x != expect_value'




@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_deepfool_attack_detection():
"""
Deepfool-Attack test
"""
net = Net2()
inputs1_np = np.random.random((2, 10, 10)).astype(np.float32)
inputs2_np = np.random.random((2, 10, 5)).astype(np.float32)
gt_boxes = inputs1_np[:, :, 0: 5]
gt_labels = np.argmax(inputs1_np, axis=2)
num_classes = 10

attack = DeepFool(net, num_classes, model_type='detection', reserve_ratio=0.3,
bounds=(0.0, 1.0))
_ = attack.generate((inputs1_np, inputs2_np), (gt_boxes, gt_labels))


@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training


Loading…
Cancel
Save