| @@ -18,7 +18,7 @@ from abc import abstractmethod | |||||
| import numpy as np | import numpy as np | ||||
| from mindarmour.utils._check_param import check_pair_numpy_param, \ | |||||
| from mindarmour.utils._check_param import check_inputs_labels, \ | |||||
| check_int_positive, check_equal_shape, check_numpy_param, check_model | check_int_positive, check_equal_shape, check_numpy_param, check_model | ||||
| from mindarmour.utils.util import calculate_iou | from mindarmour.utils.util import calculate_iou | ||||
| from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
| @@ -55,18 +55,7 @@ class Attack: | |||||
| >>> labels = np.array([3, 0]) | >>> labels = np.array([3, 0]) | ||||
| >>> advs = attack.batch_generate(inputs, labels, batch_size=2) | >>> advs = attack.batch_generate(inputs, labels, batch_size=2) | ||||
| """ | """ | ||||
| inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs | |||||
| if isinstance(inputs, tuple): | |||||
| for i, inputs_item in enumerate(inputs): | |||||
| _ = check_pair_numpy_param('inputs_image', inputs_image, \ | |||||
| 'inputs[{}]'.format(i), inputs_item) | |||||
| if isinstance(labels, tuple): | |||||
| for i, labels_item in enumerate(labels): | |||||
| _ = check_pair_numpy_param('inputs_image', inputs_image, \ | |||||
| 'labels[{}]'.format(i), labels_item) | |||||
| else: | |||||
| _ = check_pair_numpy_param('inputs', inputs_image, \ | |||||
| 'labels', labels) | |||||
| inputs_image, inputs, labels = check_inputs_labels(inputs, labels) | |||||
| arr_x = inputs | arr_x = inputs | ||||
| arr_y = labels | arr_y = labels | ||||
| len_x = inputs_image.shape[0] | len_x = inputs_image.shape[0] | ||||
| @@ -20,16 +20,76 @@ from mindspore import Tensor | |||||
| from mindspore.nn import Cell | from mindspore.nn import Cell | ||||
| from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
| from mindarmour.utils.util import GradWrap, jacobian_matrix | |||||
| from mindarmour.utils.util import GradWrap, jacobian_matrix, \ | |||||
| jacobian_matrix_for_detection, calculate_iou, to_tensor_tuple | |||||
| from mindarmour.utils._check_param import check_pair_numpy_param, check_model, \ | from mindarmour.utils._check_param import check_pair_numpy_param, check_model, \ | ||||
| check_value_positive, check_int_positive, check_norm_level, \ | check_value_positive, check_int_positive, check_norm_level, \ | ||||
| check_param_multi_types, check_param_type | |||||
| check_param_multi_types, check_param_type, check_value_non_negative | |||||
| from .attack import Attack | from .attack import Attack | ||||
| LOGGER = LogUtil.get_instance() | LOGGER = LogUtil.get_instance() | ||||
| TAG = 'DeepFool' | TAG = 'DeepFool' | ||||
| class _GetLogits(Cell): | |||||
| def __init__(self, network): | |||||
| super(_GetLogits, self).__init__() | |||||
| self._network = network | |||||
| def construct(self, *inputs): | |||||
| _, pre_logits = self._network(*inputs) | |||||
| return pre_logits | |||||
| def _deepfool_detection_scores(inputs, gt_boxes, gt_labels, network): | |||||
| """ | |||||
| Evaluate the detection result of inputs, specially for object detection models. | |||||
| Args: | |||||
| inputs (numpy.ndarray): Input samples. | |||||
| gt_boxes (numpy.ndarray): Ground-truth boxes of inputs. | |||||
| gt_labels (numpy.ndarray): Ground-truth labels of inputs. | |||||
| model (BlackModel): Target model. | |||||
| Returns: | |||||
| - numpy.ndarray, detection scores of inputs. | |||||
| - numpy.ndarray, the number of objects that are correctly detected. | |||||
| """ | |||||
| network = check_param_type('network', network, Cell) | |||||
| inputs_tensor = to_tensor_tuple(inputs) | |||||
| box_and_confi, pred_logits = network(*inputs_tensor) | |||||
| box_and_confi, pred_logits = box_and_confi.asnumpy(), pred_logits.asnumpy() | |||||
| pred_labels = np.argmax(pred_logits, axis=2) | |||||
| det_scores = [] | |||||
| correct_labels_num = [] | |||||
| gt_boxes_num = gt_boxes.shape[0] | |||||
| iou_thres = 0.5 | |||||
| for idx, (boxes, labels) in enumerate(zip(box_and_confi, pred_labels)): | |||||
| score = 0 | |||||
| box_num = boxes.shape[0] | |||||
| correct_label_flag = np.zeros(gt_labels.shape) | |||||
| gt_boxes_idx = gt_boxes[idx] | |||||
| gt_labels_idx = gt_labels[idx] | |||||
| for i in range(box_num): | |||||
| pred_box = boxes[i] | |||||
| max_iou_confi = 0 | |||||
| for j in range(gt_boxes_num): | |||||
| iou = calculate_iou(pred_box[:4], gt_boxes_idx[j][:4]) | |||||
| if labels[i] == gt_labels_idx[j] and iou > iou_thres: | |||||
| max_iou_confi = max(max_iou_confi, pred_box[-1] + iou) | |||||
| correct_label_flag[j] = 1 | |||||
| score += max_iou_confi | |||||
| det_scores.append(score) | |||||
| correct_labels_num.append(np.sum(correct_label_flag)) | |||||
| return np.array(det_scores), np.array(correct_labels_num) | |||||
| def _is_success(inputs, gt_boxes, gt_labels, network, gt_object_nums, reserve_ratio): | |||||
| _, correct_nums_adv = _deepfool_detection_scores(inputs, gt_boxes, gt_labels, network) | |||||
| return np.all(correct_nums_adv <= (gt_object_nums*reserve_ratio).astype(np.int32)) | |||||
| class DeepFool(Attack): | class DeepFool(Attack): | ||||
| """ | """ | ||||
| DeepFool is an untargeted & iterative attack achieved by moving the benign | DeepFool is an untargeted & iterative attack achieved by moving the benign | ||||
| @@ -56,8 +116,8 @@ class DeepFool(Attack): | |||||
| >>> attack = DeepFool(network) | >>> attack = DeepFool(network) | ||||
| """ | """ | ||||
| def __init__(self, network, num_classes, max_iters=50, overshoot=0.02, | |||||
| norm_level=2, bounds=None, sparse=True): | |||||
| def __init__(self, network, num_classes, model_type='classification', | |||||
| reserve_ratio=0.3, max_iters=50, overshoot=0.02, norm_level=2, bounds=None, sparse=True): | |||||
| super(DeepFool, self).__init__() | super(DeepFool, self).__init__() | ||||
| self._network = check_model('network', network, Cell) | self._network = check_model('network', network, Cell) | ||||
| self._network.set_grad(True) | self._network.set_grad(True) | ||||
| @@ -66,18 +126,32 @@ class DeepFool(Attack): | |||||
| self._norm_level = check_norm_level(norm_level) | self._norm_level = check_norm_level(norm_level) | ||||
| self._num_classes = check_int_positive('num_classes', num_classes) | self._num_classes = check_int_positive('num_classes', num_classes) | ||||
| self._net_grad = GradWrap(self._network) | self._net_grad = GradWrap(self._network) | ||||
| self._bounds = check_param_multi_types('bounds', bounds, [list, tuple]) | |||||
| self._bounds = bounds | |||||
| if self._bounds is not None: | |||||
| self._bounds = check_param_multi_types('bounds', bounds, [list, tuple]) | |||||
| for b in self._bounds: | |||||
| _ = check_param_multi_types('bound', b, [int, float]) | |||||
| self._sparse = check_param_type('sparse', sparse, bool) | self._sparse = check_param_type('sparse', sparse, bool) | ||||
| for b in self._bounds: | |||||
| _ = check_param_multi_types('bound', b, [int, float]) | |||||
| self._model_type = check_param_type('model_type', model_type, str) | |||||
| if self._model_type not in ('classification', 'detection'): | |||||
| msg = "Only 'classification' or 'detection' is supported now, but got {}.".format(self._model_type) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise ValueError(msg) | |||||
| self._reserve_ratio = check_value_non_negative('reserve_ratio', reserve_ratio) | |||||
| if self._reserve_ratio > 1: | |||||
| msg = 'reserve_ratio should be less than 1.0, but got {}.'.format(self._reserve_ratio) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise ValueError(TAG, msg) | |||||
| def generate(self, inputs, labels): | def generate(self, inputs, labels): | ||||
| """ | """ | ||||
| Generate adversarial examples based on input samples and original labels. | Generate adversarial examples based on input samples and original labels. | ||||
| Args: | Args: | ||||
| inputs (numpy.ndarray): Input samples. | |||||
| labels (numpy.ndarray): Original labels. | |||||
| inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs can be (inputs1, input2, ...) \ | |||||
| or only one array if model_type='detection' | |||||
| labels (Union[numpy.ndarray, tuple]): Original labels. The format of labels should be \ | |||||
| (gt_boxes, gt_labels) if model_type='detection'. | |||||
| Returns: | Returns: | ||||
| numpy.ndarray, adversarial examples. | numpy.ndarray, adversarial examples. | ||||
| @@ -88,67 +162,144 @@ class DeepFool(Attack): | |||||
| Examples: | Examples: | ||||
| >>> advs = generate([[0.2, 0.3, 0.4], [0.3, 0.4, 0.5]], [1, 2]) | >>> advs = generate([[0.2, 0.3, 0.4], [0.3, 0.4, 0.5]], [1, 2]) | ||||
| """ | """ | ||||
| inputs, labels = check_pair_numpy_param('inputs', inputs, | |||||
| 'labels', labels) | |||||
| if not self._sparse: | |||||
| labels = np.argmax(labels, axis=1) | |||||
| inputs_dtype = inputs.dtype | |||||
| iteration = 0 | |||||
| origin_labels = labels | |||||
| cur_labels = origin_labels.copy() | |||||
| weight = np.squeeze(np.zeros(inputs.shape[1:])) | |||||
| r_tot = np.zeros(inputs.shape) | |||||
| x_origin = inputs | |||||
| while np.any(cur_labels == origin_labels) and iteration < self._max_iters: | |||||
| preds = self._network(Tensor(inputs)).asnumpy() | |||||
| grads = jacobian_matrix(self._net_grad, inputs, self._num_classes) | |||||
| for idx in range(inputs.shape[0]): | |||||
| diff_w = np.inf | |||||
| label = origin_labels[idx] | |||||
| if cur_labels[idx] != label: | |||||
| continue | |||||
| for k in range(self._num_classes): | |||||
| if k == label: | |||||
| if self._model_type == 'detection': | |||||
| images, auxiliary_inputs = inputs[0], inputs[1:] | |||||
| gt_boxes, gt_labels = labels | |||||
| _, gt_object_nums = _deepfool_detection_scores(inputs, gt_boxes, gt_labels, self._network) | |||||
| if not self._sparse: | |||||
| gt_labels = np.argmax(gt_labels, axis=2) | |||||
| origin_labels = np.zeros(gt_labels.shape[0]) | |||||
| for i in range(gt_labels.shape[0]): | |||||
| origin_labels[i] = np.argmax(np.bincount(gt_labels[i])) | |||||
| images_dtype = images.dtype | |||||
| iteration = 0 | |||||
| num_boxes = gt_labels.shape[1] | |||||
| merge_net = _GetLogits(self._network) | |||||
| detection_net_grad = GradWrap(merge_net) | |||||
| weight = np.squeeze(np.zeros(images.shape[1:])) | |||||
| r_tot = np.zeros(images.shape) | |||||
| x_origin = images | |||||
| while not _is_success((images,) + auxiliary_inputs, gt_boxes, gt_labels, self._network, gt_object_nums, \ | |||||
| self._reserve_ratio) and iteration < self._max_iters: | |||||
| preds_logits = merge_net(*to_tensor_tuple(images), *to_tensor_tuple(auxiliary_inputs)).asnumpy() | |||||
| grads = jacobian_matrix_for_detection(detection_net_grad, (images,) + auxiliary_inputs, | |||||
| num_boxes, self._num_classes) | |||||
| for idx in range(images.shape[0]): | |||||
| diff_w = np.inf | |||||
| label = int(origin_labels[idx]) | |||||
| auxiliary_input_i = tuple() | |||||
| for item in auxiliary_inputs: | |||||
| auxiliary_input_i += (np.expand_dims(item[idx], axis=0),) | |||||
| gt_boxes_i = np.expand_dims(gt_boxes[idx], axis=0) | |||||
| gt_labels_i = np.expand_dims(gt_labels[idx], axis=0) | |||||
| inputs_i = (np.expand_dims(images[idx], axis=0),) + auxiliary_input_i | |||||
| if _is_success(inputs_i, gt_boxes_i, gt_labels_i, | |||||
| self._network, gt_object_nums[idx], self._reserve_ratio): | |||||
| continue | |||||
| for k in range(self._num_classes): | |||||
| if k == label: | |||||
| continue | |||||
| w_k = grads[k, idx, ...] - grads[label, idx, ...] | |||||
| f_k = np.mean(np.abs(preds_logits[idx, :, k, ...] - preds_logits[idx, :, label, ...])) | |||||
| if self._norm_level == 2 or self._norm_level == '2': | |||||
| diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8) | |||||
| elif self._norm_level == np.inf \ | |||||
| or self._norm_level == 'inf': | |||||
| diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8) | |||||
| else: | |||||
| msg = 'ord {} is not available.' \ | |||||
| .format(str(self._norm_level)) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise NotImplementedError(msg) | |||||
| if diff_w_k < diff_w: | |||||
| diff_w = diff_w_k | |||||
| weight = w_k | |||||
| if self._norm_level == 2 or self._norm_level == '2': | |||||
| r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8) | |||||
| elif self._norm_level == np.inf or self._norm_level == 'inf': | |||||
| r_i = diff_w*np.sign(weight) \ | |||||
| / (np.linalg.norm(weight, ord=1) + 1e-8) | |||||
| else: | |||||
| msg = 'ord {} is not available in normalization,' \ | |||||
| .format(str(self._norm_level)) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise NotImplementedError(msg) | |||||
| r_tot[idx, ...] = r_tot[idx, ...] + r_i | |||||
| if self._bounds is not None: | |||||
| clip_min, clip_max = self._bounds | |||||
| images = x_origin + (1 + self._overshoot)*r_tot*(clip_max-clip_min) | |||||
| images = np.clip(images, clip_min, clip_max) | |||||
| else: | |||||
| images = x_origin + (1 + self._overshoot)*r_tot | |||||
| iteration += 1 | |||||
| images = images.astype(images_dtype) | |||||
| del preds_logits, grads | |||||
| return images | |||||
| if self._model_type == 'classification': | |||||
| inputs, labels = check_pair_numpy_param('inputs', inputs, | |||||
| 'labels', labels) | |||||
| if not self._sparse: | |||||
| labels = np.argmax(labels, axis=1) | |||||
| inputs_dtype = inputs.dtype | |||||
| iteration = 0 | |||||
| origin_labels = labels | |||||
| cur_labels = origin_labels.copy() | |||||
| weight = np.squeeze(np.zeros(inputs.shape[1:])) | |||||
| r_tot = np.zeros(inputs.shape) | |||||
| x_origin = inputs | |||||
| while np.any(cur_labels == origin_labels) and iteration < self._max_iters: | |||||
| preds = self._network(Tensor(inputs)).asnumpy() | |||||
| grads = jacobian_matrix(self._net_grad, inputs, self._num_classes) | |||||
| for idx in range(inputs.shape[0]): | |||||
| diff_w = np.inf | |||||
| label = origin_labels[idx] | |||||
| if cur_labels[idx] != label: | |||||
| continue | continue | ||||
| w_k = grads[k, idx, ...] - grads[label, idx, ...] | |||||
| f_k = preds[idx, k] - preds[idx, label] | |||||
| for k in range(self._num_classes): | |||||
| if k == label: | |||||
| continue | |||||
| w_k = grads[k, idx, ...] - grads[label, idx, ...] | |||||
| f_k = preds[idx, k] - preds[idx, label] | |||||
| if self._norm_level == 2 or self._norm_level == '2': | |||||
| diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8) | |||||
| elif self._norm_level == np.inf \ | |||||
| or self._norm_level == 'inf': | |||||
| diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8) | |||||
| else: | |||||
| msg = 'ord {} is not available.' \ | |||||
| .format(str(self._norm_level)) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise NotImplementedError(msg) | |||||
| if diff_w_k < diff_w: | |||||
| diff_w = diff_w_k | |||||
| weight = w_k | |||||
| if self._norm_level == 2 or self._norm_level == '2': | if self._norm_level == 2 or self._norm_level == '2': | ||||
| diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8) | |||||
| elif self._norm_level == np.inf \ | |||||
| or self._norm_level == 'inf': | |||||
| diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8) | |||||
| r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8) | |||||
| elif self._norm_level == np.inf or self._norm_level == 'inf': | |||||
| r_i = diff_w*np.sign(weight) \ | |||||
| / (np.linalg.norm(weight, ord=1) + 1e-8) | |||||
| else: | else: | ||||
| msg = 'ord {} is not available.' \ | |||||
| msg = 'ord {} is not available in normalization.' \ | |||||
| .format(str(self._norm_level)) | .format(str(self._norm_level)) | ||||
| LOGGER.error(TAG, msg) | LOGGER.error(TAG, msg) | ||||
| raise NotImplementedError(msg) | raise NotImplementedError(msg) | ||||
| if diff_w_k < diff_w: | |||||
| diff_w = diff_w_k | |||||
| weight = w_k | |||||
| if self._norm_level == 2 or self._norm_level == '2': | |||||
| r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8) | |||||
| elif self._norm_level == np.inf or self._norm_level == 'inf': | |||||
| r_i = diff_w*np.sign(weight) \ | |||||
| / (np.linalg.norm(weight, ord=1) + 1e-8) | |||||
| r_tot[idx, ...] = r_tot[idx, ...] + r_i | |||||
| if self._bounds is not None: | |||||
| clip_min, clip_max = self._bounds | |||||
| inputs = x_origin + (1 + self._overshoot)*r_tot*(clip_max | |||||
| - clip_min) | |||||
| inputs = np.clip(inputs, clip_min, clip_max) | |||||
| else: | else: | ||||
| msg = 'ord {} is not available in normalization.' \ | |||||
| .format(str(self._norm_level)) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise NotImplementedError(msg) | |||||
| r_tot[idx, ...] = r_tot[idx, ...] + r_i | |||||
| if self._bounds is not None: | |||||
| clip_min, clip_max = self._bounds | |||||
| inputs = x_origin + (1 + self._overshoot)*r_tot*(clip_max | |||||
| - clip_min) | |||||
| inputs = np.clip(inputs, clip_min, clip_max) | |||||
| else: | |||||
| inputs = x_origin + (1 + self._overshoot)*r_tot | |||||
| cur_labels = np.argmax( | |||||
| self._network(Tensor(inputs.astype(inputs_dtype))).asnumpy(), | |||||
| axis=1) | |||||
| iteration += 1 | |||||
| inputs = inputs.astype(inputs_dtype) | |||||
| del preds, grads | |||||
| return inputs | |||||
| inputs = x_origin + (1 + self._overshoot)*r_tot | |||||
| cur_labels = np.argmax( | |||||
| self._network(Tensor(inputs.astype(inputs_dtype))).asnumpy(), | |||||
| axis=1) | |||||
| iteration += 1 | |||||
| inputs = inputs.astype(inputs_dtype) | |||||
| del preds, grads | |||||
| return inputs | |||||
| return None | |||||
| @@ -18,12 +18,11 @@ from abc import abstractmethod | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore import Tensor | |||||
| from mindspore.nn import Cell | from mindspore.nn import Cell | ||||
| from mindarmour.utils.util import WithLossCell, GradWrapWithLoss | |||||
| from mindarmour.utils.util import WithLossCell, GradWrapWithLoss, to_tensor_tuple | |||||
| from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
| from mindarmour.utils._check_param import check_pair_numpy_param, check_model, \ | |||||
| from mindarmour.utils._check_param import check_model, check_inputs_labels, \ | |||||
| normalize_value, check_value_positive, check_param_multi_types, \ | normalize_value, check_value_positive, check_param_multi_types, \ | ||||
| check_norm_level, check_param_type | check_norm_level, check_param_type | ||||
| from .attack import Attack | from .attack import Attack | ||||
| @@ -91,18 +90,7 @@ class GradientMethod(Attack): | |||||
| Returns: | Returns: | ||||
| numpy.ndarray, generated adversarial examples. | numpy.ndarray, generated adversarial examples. | ||||
| """ | """ | ||||
| inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs | |||||
| if isinstance(inputs, tuple): | |||||
| for i, inputs_item in enumerate(inputs): | |||||
| _ = check_pair_numpy_param('inputs_image', inputs_image, \ | |||||
| 'inputs[{}]'.format(i), inputs_item) | |||||
| if isinstance(labels, tuple): | |||||
| for i, labels_item in enumerate(labels): | |||||
| _ = check_pair_numpy_param('inputs_image', inputs_image, \ | |||||
| 'labels[{}]'.format(i), labels_item) | |||||
| else: | |||||
| _ = check_pair_numpy_param('inputs', inputs_image, \ | |||||
| 'labels', labels) | |||||
| inputs_image, inputs, labels = check_inputs_labels(inputs, labels) | |||||
| self._dtype = inputs_image.dtype | self._dtype = inputs_image.dtype | ||||
| gradient = self._gradient(inputs, labels) | gradient = self._gradient(inputs, labels) | ||||
| # use random method or not | # use random method or not | ||||
| @@ -196,18 +184,8 @@ class FastGradientMethod(GradientMethod): | |||||
| Returns: | Returns: | ||||
| numpy.ndarray, gradient of inputs. | numpy.ndarray, gradient of inputs. | ||||
| """ | """ | ||||
| if isinstance(inputs, tuple): | |||||
| inputs_tensor = tuple() | |||||
| for item in inputs: | |||||
| inputs_tensor += (Tensor(item),) | |||||
| else: | |||||
| inputs_tensor = (Tensor(inputs),) | |||||
| if isinstance(labels, tuple): | |||||
| labels_tensor = tuple() | |||||
| for item in labels: | |||||
| labels_tensor += (Tensor(item),) | |||||
| else: | |||||
| labels_tensor = (Tensor(labels),) | |||||
| inputs_tensor = to_tensor_tuple(inputs) | |||||
| labels_tensor = to_tensor_tuple(labels) | |||||
| out_grad = self._grad_all(*inputs_tensor, *labels_tensor) | out_grad = self._grad_all(*inputs_tensor, *labels_tensor) | ||||
| if isinstance(out_grad, tuple): | if isinstance(out_grad, tuple): | ||||
| out_grad = out_grad[0] | out_grad = out_grad[0] | ||||
| @@ -315,18 +293,8 @@ class FastGradientSignMethod(GradientMethod): | |||||
| Returns: | Returns: | ||||
| numpy.ndarray, gradient of inputs. | numpy.ndarray, gradient of inputs. | ||||
| """ | """ | ||||
| if isinstance(inputs, tuple): | |||||
| inputs_tensor = tuple() | |||||
| for item in inputs: | |||||
| inputs_tensor += (Tensor(item),) | |||||
| else: | |||||
| inputs_tensor = (Tensor(inputs),) | |||||
| if isinstance(labels, tuple): | |||||
| labels_tensor = tuple() | |||||
| for item in labels: | |||||
| labels_tensor += (Tensor(item),) | |||||
| else: | |||||
| labels_tensor = (Tensor(labels),) | |||||
| inputs_tensor = to_tensor_tuple(inputs) | |||||
| labels_tensor = to_tensor_tuple(labels) | |||||
| out_grad = self._grad_all(*inputs_tensor, *labels_tensor) | out_grad = self._grad_all(*inputs_tensor, *labels_tensor) | ||||
| if isinstance(out_grad, tuple): | if isinstance(out_grad, tuple): | ||||
| out_grad = out_grad[0] | out_grad = out_grad[0] | ||||
| @@ -18,11 +18,10 @@ import numpy as np | |||||
| from PIL import Image, ImageOps | from PIL import Image, ImageOps | ||||
| from mindspore.nn import Cell | from mindspore.nn import Cell | ||||
| from mindspore import Tensor | |||||
| from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
| from mindarmour.utils.util import WithLossCell, GradWrapWithLoss | |||||
| from mindarmour.utils._check_param import check_pair_numpy_param, \ | |||||
| from mindarmour.utils.util import WithLossCell, GradWrapWithLoss, to_tensor_tuple | |||||
| from mindarmour.utils._check_param import check_inputs_labels, \ | |||||
| normalize_value, check_model, check_value_positive, check_int_positive, \ | normalize_value, check_model, check_value_positive, check_int_positive, \ | ||||
| check_param_type, check_norm_level, check_param_multi_types | check_param_type, check_norm_level, check_param_multi_types | ||||
| from .attack import Attack | from .attack import Attack | ||||
| @@ -223,18 +222,7 @@ class BasicIterativeMethod(IterativeGradientMethod): | |||||
| >>> [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0], | >>> [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0], | ||||
| >>> [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]]) | >>> [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]]) | ||||
| """ | """ | ||||
| inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs | |||||
| if isinstance(inputs, tuple): | |||||
| for i, inputs_item in enumerate(inputs): | |||||
| _ = check_pair_numpy_param('inputs_image', inputs_image, \ | |||||
| 'inputs[{}]'.format(i), inputs_item) | |||||
| if isinstance(labels, tuple): | |||||
| for i, labels_item in enumerate(labels): | |||||
| _ = check_pair_numpy_param('inputs_image', inputs_image, \ | |||||
| 'labels[{}]'.format(i), labels_item) | |||||
| else: | |||||
| _ = check_pair_numpy_param('inputs', inputs_image, \ | |||||
| 'labels', labels) | |||||
| inputs_image, inputs, labels = check_inputs_labels(inputs, labels) | |||||
| arr_x = inputs_image | arr_x = inputs_image | ||||
| if self._bounds is not None: | if self._bounds is not None: | ||||
| clip_min, clip_max = self._bounds | clip_min, clip_max = self._bounds | ||||
| @@ -322,18 +310,7 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||||
| >>> [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0], | >>> [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0], | ||||
| >>> [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]]) | >>> [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]]) | ||||
| """ | """ | ||||
| inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs | |||||
| if isinstance(inputs, tuple): | |||||
| for i, inputs_item in enumerate(inputs): | |||||
| _ = check_pair_numpy_param('inputs_image', inputs_image, \ | |||||
| 'inputs[{}]'.format(i), inputs_item) | |||||
| if isinstance(labels, tuple): | |||||
| for i, labels_item in enumerate(labels): | |||||
| _ = check_pair_numpy_param('inputs_image', inputs_image, \ | |||||
| 'labels[{}]'.format(i), labels_item) | |||||
| else: | |||||
| _ = check_pair_numpy_param('inputs', inputs_image, \ | |||||
| 'labels', labels) | |||||
| inputs_image, inputs, labels = check_inputs_labels(inputs, labels) | |||||
| arr_x = inputs_image | arr_x = inputs_image | ||||
| momentum = 0 | momentum = 0 | ||||
| if self._bounds is not None: | if self._bounds is not None: | ||||
| @@ -392,18 +369,8 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||||
| >>> [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]) | >>> [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]) | ||||
| """ | """ | ||||
| # get grad of loss over x | # get grad of loss over x | ||||
| if isinstance(inputs, tuple): | |||||
| inputs_tensor = tuple() | |||||
| for item in inputs: | |||||
| inputs_tensor += (Tensor(item),) | |||||
| else: | |||||
| inputs_tensor = (Tensor(inputs),) | |||||
| if isinstance(labels, tuple): | |||||
| labels_tensor = tuple() | |||||
| for item in labels: | |||||
| labels_tensor += (Tensor(item),) | |||||
| else: | |||||
| labels_tensor = (Tensor(labels),) | |||||
| inputs_tensor = to_tensor_tuple(inputs) | |||||
| labels_tensor = to_tensor_tuple(labels) | |||||
| out_grad = self._loss_grad(*inputs_tensor, *labels_tensor) | out_grad = self._loss_grad(*inputs_tensor, *labels_tensor) | ||||
| if isinstance(out_grad, tuple): | if isinstance(out_grad, tuple): | ||||
| out_grad = out_grad[0] | out_grad = out_grad[0] | ||||
| @@ -473,18 +440,7 @@ class ProjectedGradientDescent(BasicIterativeMethod): | |||||
| >>> [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1], | >>> [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1], | ||||
| >>> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) | >>> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) | ||||
| """ | """ | ||||
| inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs | |||||
| if isinstance(inputs, tuple): | |||||
| for i, inputs_item in enumerate(inputs): | |||||
| _ = check_pair_numpy_param('inputs_image', inputs_image, \ | |||||
| 'inputs[{}]'.format(i), inputs_item) | |||||
| if isinstance(labels, tuple): | |||||
| for i, labels_item in enumerate(labels): | |||||
| _ = check_pair_numpy_param('inputs_image', inputs_image, \ | |||||
| 'labels[{}]'.format(i), labels_item) | |||||
| else: | |||||
| _ = check_pair_numpy_param('inputs', inputs_image, \ | |||||
| 'labels', labels) | |||||
| inputs_image, inputs, labels = check_inputs_labels(inputs, labels) | |||||
| arr_x = inputs_image | arr_x = inputs_image | ||||
| if self._bounds is not None: | if self._bounds is not None: | ||||
| clip_min, clip_max = self._bounds | clip_min, clip_max = self._bounds | ||||
| @@ -327,3 +327,22 @@ def check_detection_inputs(inputs, labels): | |||||
| LOGGER.error(TAG, msg) | LOGGER.error(TAG, msg) | ||||
| raise ValueError(msg) | raise ValueError(msg) | ||||
| return images, auxiliary_inputs, gt_boxes, gt_labels | return images, auxiliary_inputs, gt_boxes, gt_labels | ||||
| def check_inputs_labels(inputs, labels): | |||||
| """check inputs and labels is valid for white box method.""" | |||||
| _ = check_param_multi_types('inputs', inputs, (tuple, np.ndarray)) | |||||
| _ = check_param_multi_types('labels', labels, (tuple, np.ndarray)) | |||||
| inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs | |||||
| if isinstance(inputs, tuple): | |||||
| for i, inputs_item in enumerate(inputs): | |||||
| _ = check_pair_numpy_param('inputs_image', inputs_image, \ | |||||
| 'inputs[{}]'.format(i), inputs_item) | |||||
| if isinstance(labels, tuple): | |||||
| for i, labels_item in enumerate(labels): | |||||
| _ = check_pair_numpy_param('inputs', inputs_image, \ | |||||
| 'labels[{}]'.format(i), labels_item) | |||||
| else: | |||||
| _ = check_pair_numpy_param('inputs', inputs_image, \ | |||||
| 'labels', labels) | |||||
| return inputs_image, inputs, labels | |||||
| @@ -17,7 +17,7 @@ from mindspore import Tensor | |||||
| from mindspore.nn import Cell | from mindspore.nn import Cell | ||||
| from mindspore.ops.composite import GradOperation | from mindspore.ops.composite import GradOperation | ||||
| from mindarmour.utils._check_param import check_numpy_param | |||||
| from mindarmour.utils._check_param import check_numpy_param, check_param_multi_types | |||||
| from .logger import LogUtil | from .logger import LogUtil | ||||
| @@ -54,6 +54,44 @@ def jacobian_matrix(grad_wrap_net, inputs, num_classes): | |||||
| return np.asarray(grads_matrix) | return np.asarray(grads_matrix) | ||||
| def jacobian_matrix_for_detection(grad_wrap_net, inputs, num_boxes, num_classes): | |||||
| """ | |||||
| Calculate the Jacobian matrix for inputs, specifically for object detection model. | |||||
| Args: | |||||
| grad_wrap_net (Cell): A network wrapped by GradWrap. | |||||
| inputs (numpy.ndarray): Input samples. | |||||
| num_boxes (int): Number of boxes infered by each image. | |||||
| num_classes (int): Number of labels of model output. | |||||
| Returns: | |||||
| numpy.ndarray, the Jacobian matrix of inputs. (labels, batch_size, ...) | |||||
| Raises: | |||||
| ValueError: If grad_wrap_net is not a instance of class `GradWrap`. | |||||
| """ | |||||
| if not isinstance(grad_wrap_net, GradWrap): | |||||
| msg = 'grad_wrap_net be and instance of class `GradWrap`.' | |||||
| LOGGER.error(TAG, msg) | |||||
| raise ValueError(msg) | |||||
| grad_wrap_net.set_train() | |||||
| grads_matrix = [] | |||||
| inputs_tensor = tuple() | |||||
| if isinstance(inputs, tuple): | |||||
| for item in inputs: | |||||
| inputs_tensor += (Tensor(item),) | |||||
| else: | |||||
| inputs_tensor += (Tensor(inputs),) | |||||
| for idx in range(num_classes): | |||||
| batch_size = inputs[0].shape[0] if isinstance(inputs, tuple) else inputs.shape[0] | |||||
| sens = np.zeros((batch_size, num_boxes, num_classes)).astype(np.float32) | |||||
| sens[:, :, idx] = 1.0 | |||||
| grads = grad_wrap_net(*(inputs_tensor), Tensor(sens)) | |||||
| grads_matrix.append(grads.asnumpy()) | |||||
| return np.asarray(grads_matrix) | |||||
| class WithLossCell(Cell): | class WithLossCell(Cell): | ||||
| """ | """ | ||||
| Wrap the network with loss function. | Wrap the network with loss function. | ||||
| @@ -152,19 +190,19 @@ class GradWrap(Cell): | |||||
| self.grad = GradOperation(get_all=False, sens_param=True) | self.grad = GradOperation(get_all=False, sens_param=True) | ||||
| self.network = network | self.network = network | ||||
| def construct(self, inputs, weight): | |||||
| def construct(self, *data): | |||||
| """ | """ | ||||
| Compute jacobian matrix. | Compute jacobian matrix. | ||||
| Args: | Args: | ||||
| inputs (Tensor): Inputs of network. | |||||
| weight (Tensor): Weight of each gradient, `weight` has the same | |||||
| shape with labels. | |||||
| data (Tensor): Data consists of inputs and weight. \ | |||||
| - inputs: Inputs of network. \ | |||||
| - weight: Weight of each gradient, 'weight' has the same shape with labels. | |||||
| Returns: | Returns: | ||||
| Tensor, Jacobian matrix. | Tensor, Jacobian matrix. | ||||
| """ | """ | ||||
| gout = self.grad(self.network)(inputs, weight) | |||||
| gout = self.grad(self.network)(*data) | |||||
| return gout | return gout | ||||
| @@ -199,3 +237,15 @@ def calculate_iou(box_i, box_j): | |||||
| return 0 | return 0 | ||||
| inner_area = (inner_right_line - inner_left_line)*(inner_top_line - inner_bottom_line) | inner_area = (inner_right_line - inner_left_line)*(inner_top_line - inner_bottom_line) | ||||
| return inner_area / (s_i + s_j - inner_area) | return inner_area / (s_i + s_j - inner_area) | ||||
| def to_tensor_tuple(inputs_ori): | |||||
| """Transfer inputs data into tensor type.""" | |||||
| inputs_ori = check_param_multi_types('inputs_ori', inputs_ori, [np.ndarray, tuple]) | |||||
| if isinstance(inputs_ori, tuple): | |||||
| inputs_tensor = tuple() | |||||
| for item in inputs_ori: | |||||
| inputs_tensor += (Tensor(item),) | |||||
| else: | |||||
| inputs_tensor = (Tensor(inputs_ori),) | |||||
| return inputs_tensor | |||||
| @@ -54,6 +54,23 @@ class Net(Cell): | |||||
| return out | return out | ||||
| class Net2(Cell): | |||||
| """ | |||||
| Construct the network of target model, specifically for detection model test case. | |||||
| Examples: | |||||
| >>> net = Net2() | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Net2, self).__init__() | |||||
| self._softmax = P.Softmax() | |||||
| def construct(self, inputs1, inputs2): | |||||
| out1 = self._softmax(inputs2) | |||||
| out2 = self._softmax(inputs1) | |||||
| return out1, out2 | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_arm_ascend_training | @pytest.mark.platform_arm_ascend_training | ||||
| @pytest.mark.platform_x86_ascend_training | @pytest.mark.platform_x86_ascend_training | ||||
| @@ -79,6 +96,27 @@ def test_deepfool_attack(): | |||||
| ' implementation error, ms_adv_x != expect_value' | ' implementation error, ms_adv_x != expect_value' | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_card | |||||
| @pytest.mark.component_mindarmour | |||||
| def test_deepfool_attack_detection(): | |||||
| """ | |||||
| Deepfool-Attack test | |||||
| """ | |||||
| net = Net2() | |||||
| inputs1_np = np.random.random((2, 10, 10)).astype(np.float32) | |||||
| inputs2_np = np.random.random((2, 10, 5)).astype(np.float32) | |||||
| gt_boxes = inputs1_np[:, :, 0: 5] | |||||
| gt_labels = np.argmax(inputs1_np, axis=2) | |||||
| num_classes = 10 | |||||
| attack = DeepFool(net, num_classes, model_type='detection', reserve_ratio=0.3, | |||||
| bounds=(0.0, 1.0)) | |||||
| _ = attack.generate((inputs1_np, inputs2_np), (gt_boxes, gt_labels)) | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_arm_ascend_training | @pytest.mark.platform_arm_ascend_training | ||||
| @pytest.mark.platform_x86_ascend_training | @pytest.mark.platform_x86_ascend_training | ||||