@@ -42,6 +42,6 @@ Dataset used: [COCO2017](<https://cocodataset.org/>)
 # Quick start
 You can download the pre-trained model checkpoint file [here](<https://www.mindspore.cn/resources/hub/details?2505/MindSpore/ascend/0.7/fasterrcnn_v1.0_coco2017>).
 ```
-python coco_attack_pgd.py --ann_file [VAL_JSON_FILE] --pre_trained [PRETRAINED_CHECKPOINT_FILE]
+python coco_attack_pgd.py --pre_trained [PRETRAINED_CHECKPOINT_FILE]
 ```
 > Adversarial samples will be generated and saved as a pickle file.
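For reference, a minimal sketch of loading the result (assuming the script's default output name `adv_samples.pkl` in the working directory, holding one numpy image array per sample, as the attack script below shows):

```python
import pickle

import numpy as np

# Load the adversarial samples written by coco_attack_pgd.py.
with open('adv_samples.pkl', 'rb') as f:
    adv_samples = pickle.load(f)

print(len(adv_samples), 'adversarial samples')
print(np.asarray(adv_samples[0]).shape)  # one image array per entry
```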
@@ -33,7 +33,6 @@ from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset
 set_seed(1)
 parser = argparse.ArgumentParser(description='FasterRCNN attack')
-parser.add_argument('--ann_file', type=str, required=True, help='Ann file path.')
 parser.add_argument('--pre_trained', type=str, required=True, help='pre-trained ckpt file path for target model.')
 parser.add_argument('--device_id', type=int, default=0, help='Device id, default is 0.')
 parser.add_argument('--num', type=int, default=5, help='Number of adversarial examples.')
@@ -55,7 +54,7 @@ class WithLossCell(Cell):
         self._backbone = backbone
         self._loss_fn = loss_fn

-    def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_num):
+    def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_num, *labels):
         loss1, loss2, loss3, loss4, loss5, loss6 = self._backbone(img_data, img_metas, gt_bboxes, gt_labels, gt_num)
         return self._loss_fn(loss1, loss2, loss3, loss4, loss5, loss6)
@@ -74,8 +73,8 @@ class GradWrapWithLoss(Cell):
         self._grad_all = GradOperation(get_all=True, sens_param=False)
         self._network = network

-    def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_num):
-        gout = self._grad_all(self._network)(img_data, img_metas, gt_bboxes, gt_labels, gt_num)
+    def construct(self, *inputs):
+        gout = self._grad_all(self._network)(*inputs)
         return gout[0]
@@ -84,7 +83,6 @@ if __name__ == '__main__':
     mindrecord_dir = config.mindrecord_dir
     mindrecord_file = os.path.join(mindrecord_dir, prefix)
     pre_trained = args.pre_trained
-    ann_file = args.ann_file
     print("CHECKING MINDRECORD FILES ...")
     if not os.path.exists(mindrecord_file):
@@ -116,7 +114,7 @@ if __name__ == '__main__':
     num = args.num
     num_batches = num // config.test_batch_size
     channel = 3
-    adv_samples = [0] * (num_batches * config.test_batch_size)
+    adv_samples = [0]*(num_batches*config.test_batch_size)
     adv_id = 0
     for data in ds.create_dict_iterator(num_epochs=num_batches):
         img_data = data['image']
@@ -125,11 +123,13 @@ if __name__ == '__main__':
         gt_labels = data['label']
         gt_num = data['valid_num']
-        adv_img = attack.generate(img_data.asnumpy(), \
-                                  (img_metas.asnumpy(), gt_bboxes.asnumpy(), gt_labels.asnumpy(), gt_num.asnumpy()))
+        adv_img = attack.generate((img_data.asnumpy(), \
+                                   img_metas.asnumpy(), gt_bboxes.asnumpy(), gt_labels.asnumpy(), gt_num.asnumpy()), gt_labels.asnumpy())
         for item in adv_img:
             adv_samples[adv_id] = item
             adv_id += 1
+        if adv_id >= num_batches*config.test_batch_size:
+            break
     pickle.dump(adv_samples, open('adv_samples.pkl', 'wb'))
     print('Generate adversarial samples complete.')
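With `--ann_file` gone, the script is driven purely by the prepared MindRecord data, and the rewritten `attack.generate` call packs every network input into one tuple, image first, with the attack labels passed separately. A minimal sketch of that calling convention on dummy arrays (shapes are illustrative, not the real FasterRCNN ones, and `attack` stands for the PGD instance the script builds):

```python
import numpy as np

# Illustrative shapes only -- not the real FasterRCNN input sizes.
img_data = np.random.rand(2, 3, 768, 1280).astype(np.float32)
img_metas = np.tile(np.array([768, 1280, 1.0, 1.0], np.float32), (2, 1))
gt_bboxes = np.zeros((2, 128, 4), np.float32)
gt_labels = np.zeros((2, 128), np.int32)
gt_num = np.ones((2, 128), np.bool_)

# New convention: every network input goes into the first tuple, image
# first; the labels argument is what the attack treats as the target.
inputs = (img_data, img_metas, gt_bboxes, gt_labels, gt_num)
# adv_img = attack.generate(inputs, gt_labels)
```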
@@ -41,7 +41,7 @@ class Attack:
         their labels.

         Args:
-            inputs (numpy.ndarray): Samples based on which adversarial
+            inputs (Union[numpy.ndarray, tuple]): Samples based on which adversarial
                 examples are generated.
             labels (Union[numpy.ndarray, tuple]): Original/target labels. \
                 For each input if it has more than one label, it is wrapped in a tuple.
@@ -55,21 +55,30 @@ class Attack:
             >>> labels = np.array([3, 0])
             >>> advs = attack.batch_generate(inputs, labels, batch_size=2)
         """
+        inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs
+        if isinstance(inputs, tuple):
+            for i, inputs_item in enumerate(inputs):
+                _ = check_pair_numpy_param('inputs_image', inputs_image, \
+                                           'inputs[{}]'.format(i), inputs_item)
         if isinstance(labels, tuple):
             for i, labels_item in enumerate(labels):
-                arr_x, _ = check_pair_numpy_param('inputs', inputs, \
+                _ = check_pair_numpy_param('inputs_image', inputs_image, \
                                            'labels[{}]'.format(i), labels_item)
         else:
-            arr_x, _ = check_pair_numpy_param('inputs', inputs, \
+            _ = check_pair_numpy_param('inputs', inputs_image, \
                                        'labels', labels)
+        arr_x = inputs
         arr_y = labels
-        len_x = arr_x.shape[0]
+        len_x = inputs_image.shape[0]
         batch_size = check_int_positive('batch_size', batch_size)
         batches = int(len_x / batch_size)
         rest = len_x - batches*batch_size
         res = []
         for i in range(batches):
-            x_batch = arr_x[i*batch_size: (i + 1)*batch_size]
+            if isinstance(arr_x, tuple):
+                x_batch = tuple([sub_items[i*batch_size: (i + 1)*batch_size] for sub_items in arr_x])
+            else:
+                x_batch = arr_x[i*batch_size: (i + 1)*batch_size]
             if isinstance(arr_y, tuple):
                 y_batch = tuple([sub_labels[i*batch_size: (i + 1)*batch_size] for sub_labels in arr_y])
             else:
@@ -79,12 +88,14 @@ class Attack:
             res.append(adv_x[1] if isinstance(adv_x, tuple) else adv_x)
         if rest != 0:
-            x_batch = arr_x[batches*batch_size:]
+            if isinstance(arr_x, tuple):
+                x_batch = tuple([sub_items[batches*batch_size:] for sub_items in arr_x])
+            else:
+                x_batch = arr_x[batches*batch_size:]
             if isinstance(arr_y, tuple):
                 y_batch = tuple([sub_labels[batches*batch_size:] for sub_labels in arr_y])
             else:
                 y_batch = arr_y[batches*batch_size:]
-            y_batch = arr_y[batches*batch_size:]
             adv_x = self.generate(x_batch, y_batch)
             # Black-attack methods will return 3 values, just get the second.
             res.append(adv_x[1] if isinstance(adv_x, tuple) else adv_x)
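The tuple branches above apply the same batch window to every element of `inputs`, so the image and the auxiliary arrays stay aligned sample by sample. A self-contained sketch of that slicing (hypothetical arrays, not the attack's real inputs):

```python
import numpy as np

arr_x = (np.arange(10).reshape(10, 1),    # images, one row per sample
         np.arange(20).reshape(10, 2))    # auxiliary input, same leading dim
batch_size, i = 4, 0

# The same window applied to every tuple element keeps the arrays aligned.
x_batch = tuple(sub_items[i*batch_size: (i + 1)*batch_size] for sub_items in arr_x)
print([sub.shape for sub in x_batch])     # [(4, 1), (4, 2)]
```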
@@ -98,7 +109,7 @@ class Attack:
         Generate adversarial examples based on normal samples and their labels.

         Args:
-            inputs (numpy.ndarray): Samples based on which adversarial
+            inputs (Union[numpy.ndarray, tuple]): Samples based on which adversarial
                 examples are generated.
             labels (Union[numpy.ndarray, tuple]): Original/target labels. \
                 For each input if it has more than one label, it is wrapped in a tuple.
@@ -78,8 +78,7 @@ class SaltAndPepperNoiseAttack(Attack):
         Examples:
             >>> adv_list = attack.generate(([[0.1, 0.2, 0.6],
             >>>                              [0.3, 0, 0.4]],
-            >>>                             [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
-            >>>                              [0, , 0, 1, 0, 0, 0, 0, 0, 0, 0]])
+            >>>                             [1, 2]))
         """
         arr_x, arr_y = check_pair_numpy_param('inputs', inputs, 'labels',
                                               labels)
@@ -83,7 +83,7 @@ class GradientMethod(Attack):
         Generate adversarial examples based on input samples and original/target labels.

         Args:
-            inputs (numpy.ndarray): Benign input samples used as references to create
+            inputs (Union[numpy.ndarray, tuple]): Benign input samples used as references to create
                 adversarial examples.
             labels (Union[numpy.ndarray, tuple]): Original/target labels. \
                 For each input if it has more than one label, it is wrapped in a tuple.
@@ -91,14 +91,19 @@ class GradientMethod(Attack):
         Returns:
             numpy.ndarray, generated adversarial examples.
         """
+        inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs
+        if isinstance(inputs, tuple):
+            for i, inputs_item in enumerate(inputs):
+                _ = check_pair_numpy_param('inputs_image', inputs_image, \
+                                           'inputs[{}]'.format(i), inputs_item)
         if isinstance(labels, tuple):
             for i, labels_item in enumerate(labels):
-                inputs, _ = check_pair_numpy_param('inputs', inputs, \
+                _ = check_pair_numpy_param('inputs_image', inputs_image, \
                                            'labels[{}]'.format(i), labels_item)
         else:
-            inputs, _ = check_pair_numpy_param('inputs', inputs, \
+            _ = check_pair_numpy_param('inputs', inputs_image, \
                                        'labels', labels)
-        self._dtype = inputs.dtype
+        self._dtype = inputs_image.dtype
         gradient = self._gradient(inputs, labels)
         # use random method or not
         if self._alpha is not None:
@@ -111,10 +116,10 @@ class GradientMethod(Attack):
         if self._bounds is not None:
             clip_min, clip_max = self._bounds
             perturbation = perturbation*(clip_max - clip_min)
-            adv_x = inputs + perturbation
+            adv_x = inputs_image + perturbation
             adv_x = np.clip(adv_x, clip_min, clip_max)
         else:
-            adv_x = inputs + perturbation
+            adv_x = inputs_image + perturbation
         return adv_x

     @abstractmethod
@@ -123,7 +128,7 @@ class GradientMethod(Attack):
         Calculate gradients based on input samples and original/target labels.

         Args:
-            inputs (numpy.ndarray): Benign input samples used as references to
+            inputs (Union[numpy.ndarray, tuple]): Benign input samples used as references to
                 create adversarial examples.
             labels (Union[numpy.ndarray, tuple]): Original/target labels. \
                 For each input if it has more than one label, it is wrapped in a tuple.
@@ -184,20 +189,26 @@ class FastGradientMethod(GradientMethod):
         Calculate gradients based on input samples and original/target labels.

         Args:
-            inputs (numpy.ndarray): Input sample.
+            inputs (Union[numpy.ndarray, tuple]): Input sample.
             labels (Union[numpy.ndarray, tuple]): Original/target labels. \
                 For each input if it has more than one label, it is wrapped in a tuple.

         Returns:
             numpy.ndarray, gradient of inputs.
         """
+        if isinstance(inputs, tuple):
+            inputs_tensor = tuple()
+            for item in inputs:
+                inputs_tensor += (Tensor(item),)
+        else:
+            inputs_tensor = (Tensor(inputs),)
         if isinstance(labels, tuple):
             labels_tensor = tuple()
             for item in labels:
                 labels_tensor += (Tensor(item),)
         else:
             labels_tensor = (Tensor(labels),)
-        out_grad = self._grad_all(Tensor(inputs), *labels_tensor)
+        out_grad = self._grad_all(*inputs_tensor, *labels_tensor)
         if isinstance(out_grad, tuple):
             out_grad = out_grad[0]
         gradient = out_grad.asnumpy()
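Each array is wrapped in its own `Tensor` so that one starred call can feed the gradient network regardless of how many inputs the model takes. A pure-numpy sketch of the unpacking (the `wrap` and `grad_all` stand-ins below are hypothetical placeholders for `Tensor` and `self._grad_all`):

```python
import numpy as np

def wrap(item):
    """Hypothetical stand-in for mindspore.Tensor in this sketch."""
    return np.asarray(item)

def grad_all(*args):
    """Hypothetical stand-in for self._grad_all: one gradient per input."""
    return tuple(np.ones_like(a) for a in args)

inputs = (np.zeros((1, 3), np.float32), np.zeros((1, 4), np.float32))
labels = np.zeros((1, 3), np.float32)

inputs_tensor = tuple(wrap(item) for item in inputs)
labels_tensor = (wrap(labels),)
out_grad = grad_all(*inputs_tensor, *labels_tensor)
print(out_grad[0].shape)  # (1, 3): gradient w.r.t. the first (image) input
```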
@@ -297,20 +308,26 @@ class FastGradientSignMethod(GradientMethod):
         labels.

         Args:
-            inputs (numpy.ndarray): Input samples.
-            labels (union[numpy.ndarray, tuple]): original/target labels. \
+            inputs (Union[numpy.ndarray, tuple]): Input samples.
+            labels (Union[numpy.ndarray, tuple]): Original/target labels. \
                 For each input if it has more than one label, it is wrapped in a tuple.

         Returns:
             numpy.ndarray, gradient of inputs.
         """
+        if isinstance(inputs, tuple):
+            inputs_tensor = tuple()
+            for item in inputs:
+                inputs_tensor += (Tensor(item),)
+        else:
+            inputs_tensor = (Tensor(inputs),)
         if isinstance(labels, tuple):
             labels_tensor = tuple()
             for item in labels:
                 labels_tensor += (Tensor(item),)
         else:
             labels_tensor = (Tensor(labels),)
-        out_grad = self._grad_all(Tensor(inputs), *labels_tensor)
+        out_grad = self._grad_all(*inputs_tensor, *labels_tensor)
         if isinstance(out_grad, tuple):
             out_grad = out_grad[0]
         gradient = out_grad.asnumpy()
@@ -141,7 +141,7 @@ class IterativeGradientMethod(Attack):
         Generate adversarial examples based on input samples and original/target labels.

         Args:
-            inputs (numpy.ndarray): Benign input samples used as references to create
+            inputs (Union[numpy.ndarray, tuple]): Benign input samples used as references to create
                 adversarial examples.
             labels (Union[numpy.ndarray, tuple]): Original/target labels. \
                 For each input if it has more than one label, it is wrapped in a tuple.
@@ -210,7 +210,7 @@ class BasicIterativeMethod(IterativeGradientMethod):
         Simple iterative FGSM method to generate adversarial examples.

         Args:
-            inputs (numpy.ndarray): Benign input samples used as references to
+            inputs (Union[numpy.ndarray, tuple]): Benign input samples used as references to
                 create adversarial examples.
             labels (Union[numpy.ndarray, tuple]): Original/target labels. \
                 For each input if it has more than one label, it is wrapped in a tuple.
@@ -223,36 +223,45 @@ class BasicIterativeMethod(IterativeGradientMethod):
             >>>                          [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
             >>>                           [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]])
         """
+        inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs
+        if isinstance(inputs, tuple):
+            for i, inputs_item in enumerate(inputs):
+                _ = check_pair_numpy_param('inputs_image', inputs_image, \
+                                           'inputs[{}]'.format(i), inputs_item)
         if isinstance(labels, tuple):
             for i, labels_item in enumerate(labels):
-                inputs, _ = check_pair_numpy_param('inputs', inputs, \
+                _ = check_pair_numpy_param('inputs_image', inputs_image, \
                                            'labels[{}]'.format(i), labels_item)
         else:
-            inputs, _ = check_pair_numpy_param('inputs', inputs, \
+            _ = check_pair_numpy_param('inputs', inputs_image, \
                                        'labels', labels)
-        arr_x = inputs
+        arr_x = inputs_image
         if self._bounds is not None:
             clip_min, clip_max = self._bounds
             clip_diff = clip_max - clip_min
             for _ in range(self._nb_iter):
                 if 'self._prob' in globals():
-                    d_inputs = _transform_inputs(inputs, self._prob)
+                    d_inputs = _transform_inputs(inputs_image, self._prob)
                 else:
-                    d_inputs = inputs
+                    d_inputs = inputs_image
+                if isinstance(inputs, tuple):
+                    d_inputs = (d_inputs,) + inputs[1:]
                 adv_x = self._attack.generate(d_inputs, labels)
                 perturs = np.clip(adv_x - arr_x, (0 - self._eps)*clip_diff,
                                   self._eps*clip_diff)
                 adv_x = arr_x + perturs
-                inputs = adv_x
+                inputs_image = adv_x
         else:
             for _ in range(self._nb_iter):
                 if 'self._prob' in globals():
-                    d_inputs = _transform_inputs(inputs, self._prob)
+                    d_inputs = _transform_inputs(inputs_image, self._prob)
                 else:
-                    d_inputs = inputs
+                    d_inputs = inputs_image
+                if isinstance(inputs, tuple):
+                    d_inputs = (d_inputs,) + inputs[1:]
                 adv_x = self._attack.generate(d_inputs, labels)
                 adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps)
                 inputs_image = adv_x
         return adv_x
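Only the first tuple element is perturbed across iterations; the auxiliary inputs are re-attached unchanged each step via `(d_inputs,) + inputs[1:]`. A tiny sketch of that pass-through (hypothetical arrays):

```python
import numpy as np

inputs = (np.zeros((2, 3), np.float32),   # image component, gets perturbed
          np.ones((2, 4), np.float32))    # auxiliary component, passes through
inputs_image = inputs[0]

d_inputs = inputs_image + 0.1             # stand-in for one attack step
if isinstance(inputs, tuple):
    d_inputs = (d_inputs,) + inputs[1:]   # re-attach untouched auxiliaries
assert np.array_equal(d_inputs[1], inputs[1])
```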
@@ -299,7 +308,7 @@ class MomentumIterativeMethod(IterativeGradientMethod):
         Generate adversarial examples based on input data and original/target labels.

         Args:
-            inputs (numpy.ndarray): Benign input samples used as references to
+            inputs (Union[numpy.ndarray, tuple]): Benign input samples used as references to
                 create adversarial examples.
             labels (Union[numpy.ndarray, tuple]): Original/target labels. \
                 For each input if it has more than one label, it is wrapped in a tuple.
@@ -313,42 +322,57 @@ class MomentumIterativeMethod(IterativeGradientMethod):
             >>>                          [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
             >>>                           [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]])
         """
+        inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs
+        if isinstance(inputs, tuple):
+            for i, inputs_item in enumerate(inputs):
+                _ = check_pair_numpy_param('inputs_image', inputs_image, \
+                                           'inputs[{}]'.format(i), inputs_item)
         if isinstance(labels, tuple):
             for i, labels_item in enumerate(labels):
-                inputs, _ = check_pair_numpy_param('inputs', inputs, \
+                _ = check_pair_numpy_param('inputs_image', inputs_image, \
                                            'labels[{}]'.format(i), labels_item)
         else:
-            inputs, _ = check_pair_numpy_param('inputs', inputs, \
+            _ = check_pair_numpy_param('inputs', inputs_image, \
                                        'labels', labels)
-        arr_x = inputs
+        arr_x = inputs_image
         momentum = 0
         if self._bounds is not None:
             clip_min, clip_max = self._bounds
             clip_diff = clip_max - clip_min
             for _ in range(self._nb_iter):
                 if 'self._prob' in globals():
-                    d_inputs = _transform_inputs(inputs, self._prob)
+                    d_inputs = _transform_inputs(inputs_image, self._prob)
                 else:
-                    d_inputs = inputs
+                    d_inputs = inputs_image
+                if isinstance(inputs, tuple):
+                    d_inputs = (d_inputs,) + inputs[1:]
                 gradient = self._gradient(d_inputs, labels)
                 momentum = self._decay_factor*momentum + gradient
-                adv_x = d_inputs + self._eps_iter*np.sign(momentum)
+                if isinstance(d_inputs, tuple):
+                    adv_x = d_inputs[0] + self._eps_iter*np.sign(momentum)
+                else:
+                    adv_x = d_inputs + self._eps_iter*np.sign(momentum)
                 perturs = np.clip(adv_x - arr_x, (0 - self._eps)*clip_diff,
                                   self._eps*clip_diff)
                 adv_x = arr_x + perturs
                 adv_x = np.clip(adv_x, clip_min, clip_max)
-                inputs = adv_x
+                inputs_image = adv_x
         else:
             for _ in range(self._nb_iter):
                 if 'self._prob' in globals():
-                    d_inputs = _transform_inputs(inputs, self._prob)
+                    d_inputs = _transform_inputs(inputs_image, self._prob)
                 else:
-                    d_inputs = inputs
+                    d_inputs = inputs_image
+                if isinstance(inputs, tuple):
+                    d_inputs = (d_inputs,) + inputs[1:]
                 gradient = self._gradient(d_inputs, labels)
                 momentum = self._decay_factor*momentum + gradient
-                adv_x = d_inputs + self._eps_iter*np.sign(momentum)
+                if isinstance(d_inputs, tuple):
+                    adv_x = d_inputs[0] + self._eps_iter*np.sign(momentum)
+                else:
+                    adv_x = d_inputs + self._eps_iter*np.sign(momentum)
                 adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps)
-                inputs = adv_x
+                inputs_image = adv_x
         return adv_x

     def _gradient(self, inputs, labels):
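For reference, the momentum update itself is unchanged: gradients are accumulated with a decay factor, and each step moves by `eps_iter` in the sign direction of the accumulator. A numpy sketch with illustrative values:

```python
import numpy as np

decay_factor, eps_iter = 1.0, 0.1
momentum = 0.0
x = np.array([0.2, 0.5])

# Two illustrative gradient steps; the signs agree, so the steps reinforce.
for gradient in (np.array([0.3, -0.1]), np.array([0.2, -0.4])):
    momentum = decay_factor*momentum + gradient
    x = x + eps_iter*np.sign(momentum)
print(x)  # [0.4 0.3]
```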
@@ -356,7 +380,7 @@ class MomentumIterativeMethod(IterativeGradientMethod):
         Calculate the gradient of input samples.

         Args:
-            inputs (numpy.ndarray): Input samples.
+            inputs (Union[numpy.ndarray, tuple]): Input samples.
             labels (Union[numpy.ndarray, tuple]): Original/target labels. \
                 For each input if it has more than one label, it is wrapped in a tuple.
@@ -368,13 +392,19 @@ class MomentumIterativeMethod(IterativeGradientMethod):
             >>>                 [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]])
         """
         # get grad of loss over x
+        if isinstance(inputs, tuple):
+            inputs_tensor = tuple()
+            for item in inputs:
+                inputs_tensor += (Tensor(item),)
+        else:
+            inputs_tensor = (Tensor(inputs),)
         if isinstance(labels, tuple):
             labels_tensor = tuple()
             for item in labels:
                 labels_tensor += (Tensor(item),)
         else:
             labels_tensor = (Tensor(labels),)
-        out_grad = self._loss_grad(Tensor(inputs), *labels_tensor)
+        out_grad = self._loss_grad(*inputs_tensor, *labels_tensor)
         if isinstance(out_grad, tuple):
             out_grad = out_grad[0]
         gradient = out_grad.asnumpy()
@@ -429,7 +459,7 @@ class ProjectedGradientDescent(BasicIterativeMethod):
         perturbation is normalized by the projection method with the parameter norm_level.

         Args:
-            inputs (numpy.ndarray): Benign input samples used as references to
+            inputs (Union[numpy.ndarray, tuple]): Benign input samples used as references to
                 create adversarial examples.
             labels (Union[numpy.ndarray, tuple]): Original/target labels. \
                 For each input if it has more than one label, it is wrapped in a tuple.
@@ -443,14 +473,19 @@ class ProjectedGradientDescent(BasicIterativeMethod):
             >>>                          [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
             >>>                           [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
         """
+        inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs
+        if isinstance(inputs, tuple):
+            for i, inputs_item in enumerate(inputs):
+                _ = check_pair_numpy_param('inputs_image', inputs_image, \
+                                           'inputs[{}]'.format(i), inputs_item)
         if isinstance(labels, tuple):
             for i, labels_item in enumerate(labels):
-                inputs, _ = check_pair_numpy_param('inputs', inputs, \
+                _ = check_pair_numpy_param('inputs_image', inputs_image, \
                                            'labels[{}]'.format(i), labels_item)
         else:
-            inputs, _ = check_pair_numpy_param('inputs', inputs, \
+            _ = check_pair_numpy_param('inputs', inputs_image, \
                                        'labels', labels)
-        arr_x = inputs
+        arr_x = inputs_image
         if self._bounds is not None:
             clip_min, clip_max = self._bounds
             clip_diff = clip_max - clip_min
@@ -462,7 +497,10 @@ class ProjectedGradientDescent(BasicIterativeMethod):
                 perturs = np.clip(perturs, (0 - self._eps)*clip_diff,
                                   self._eps*clip_diff)
                 adv_x = arr_x + perturs
-                inputs = adv_x
+                if isinstance(inputs, tuple):
+                    inputs = (adv_x,) + inputs[1:]
+                else:
+                    inputs = adv_x
         else:
             for _ in range(self._nb_iter):
                 adv_x = self._attack.generate(inputs, labels)
@@ -471,7 +509,10 @@ class ProjectedGradientDescent(BasicIterativeMethod):
                                       norm_level=self._norm_level)
                 adv_x = arr_x + perturs
                 adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps)
-                inputs = adv_x
+                if isinstance(inputs, tuple):
+                    inputs = (adv_x,) + inputs[1:]
+                else:
+                    inputs = adv_x
         return adv_x
@@ -580,7 +621,7 @@ def _transform_inputs(inputs, prob, low=29, high=33, full_aug=False):
         tran_outputs.append(np.array(p_sample).astype(np.float) / 255)
     if full_aug:
         # gaussian noise
-        tran_outputs = np.random.normal(tran_outputs.shape) + tran_outputs
+        tran_outputs = np.random.normal(size=np.array(tran_outputs).shape) + tran_outputs
     tran_outputs.extend(raw_inputs)
     if not np.any(tran_outputs-raw_inputs):
         LOGGER.error(TAG, 'the transform function does not take effect.')
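A note on the line above: the first positional argument of `np.random.normal` is `loc` (the mean), not the output shape, so generating noise of a given shape requires the `size=` keyword, as used in the replacement line. A quick check:

```python
import numpy as np

noise = np.random.normal(size=(2, 3))  # zero-mean noise with shape (2, 3)
print(noise.shape)                     # (2, 3)

# By contrast, np.random.normal((2, 3)) treats (2, 3) as the means and
# returns just two samples, one with mean 2 and one with mean 3.
print(np.random.normal((2, 3)).shape)  # (2,)
```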
@@ -351,8 +351,8 @@ class Fuzzer:
                 for param_name in selected_param:
                     transform.__setattr__('_' + str(param_name),
                                           selected_param[param_name])
-                mutate_sample = transform.generate([seed[0].astype(np.float32)],
-                                                   [seed[1]])[0]
+                mutate_sample = transform.generate(np.array([seed[0].astype(np.float32)]),
+                                                   np.array([seed[1]]))[0]
                 if method not in self._pixel_value_trans_list:
                     only_pixel_trans = 1
                     mutate_sample = [mutate_sample, seed[1], only_pixel_trans]
@@ -72,7 +72,15 @@ class Net2(Cell):
     def construct(self, inputs1, inputs2):
         out1 = self._relu(inputs1)
         out2 = self._relu(inputs2)
-        return out1 + out2
+        return out1 + out2, out1 - out2
+
+
+class LossNet(Cell):
+    """
+    Loss function for test.
+    """
+    def construct(self, loss1, loss2, labels1, labels2):
+        return loss1 + loss2 - labels1 - labels2


 class WithLossCell(Cell):
@@ -82,9 +90,9 @@ class WithLossCell(Cell):
         self._backbone = backbone
         self._loss_fn = loss_fn

-    def construct(self, inputs1, inputs2, labels):
+    def construct(self, inputs1, inputs2, labels1, labels2):
         out = self._backbone(inputs1, inputs2)
-        return self._loss_fn(out, labels)
+        return self._loss_fn(*out, labels1, labels2)


 class GradWrapWithLoss(Cell):
@@ -98,8 +106,8 @@ class GradWrapWithLoss(Cell):
         self._grad_all = GradOperation(get_all=True, sens_param=False)
         self._network = network

-    def construct(self, inputs1, inputs2, labels):
-        gout = self._grad_all(self._network)(inputs1, inputs2, labels)
+    def construct(self, *inputs):
+        gout = self._grad_all(self._network)(*inputs)
         return gout[0]
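`GradOperation(get_all=True)` returns one gradient per positional input, ordered like the inputs, so `gout[0]` is always the gradient w.r.t. the first input (the image); taking `*inputs` lets one wrapper serve any input arity. A minimal standalone sketch (assuming a working MindSpore install; `TinyNet` is a hypothetical toy cell, not part of the patch):

```python
import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor

class TinyNet(nn.Cell):
    """f(x, y) = sum(x*x + y): grad w.r.t. x is 2*x, grad w.r.t. y is 1."""
    def __init__(self):
        super(TinyNet, self).__init__()
        self._sum = ops.ReduceSum()

    def construct(self, x, y):
        return self._sum(x*x + y)

grad_all = ops.GradOperation(get_all=True)
grads = grad_all(TinyNet())(Tensor(np.ones(2, np.float32)),
                            Tensor(np.ones(2, np.float32)))
print(len(grads))  # 2 -- one gradient per positional input
print(grads[0])    # [2. 2.] -- gradient w.r.t. the first input
```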
@@ -285,18 +293,17 @@ def test_fast_gradient_method_multi_inputs():
     Fast gradient method unit test.
     """
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    input_np = np.asarray([[0.1, 0.2, 0.7]]).astype(np.float32)
-    anno_np = np.asarray([[0.4, 0.8, 0.5]]).astype(np.float32)
-    label = np.asarray([2], np.int32)
-    label = np.eye(3)[label].astype(np.float32)
+    inputs1 = np.asarray([[0.1, 0.2, 0.7]]).astype(np.float32)
+    inputs2 = np.asarray([[0.4, 0.8, 0.5]]).astype(np.float32)
+    labels1 = np.expand_dims(np.eye(3)[1].astype(np.float32), axis=0)
+    labels2 = np.expand_dims(np.eye(3)[2].astype(np.float32), axis=0)
-    loss_fn = SoftmaxCrossEntropyWithLogits(sparse=False)
-    with_loss_cell = WithLossCell(Net2(), loss_fn)
+    with_loss_cell = WithLossCell(Net2(), LossNet())
     grad_with_loss_net = GradWrapWithLoss(with_loss_cell)
     attack = FastGradientMethod(grad_with_loss_net)
-    ms_adv_x = attack.generate(input_np, (anno_np, label))
+    ms_adv_x = attack.generate((inputs1, inputs2), (labels1, labels2))
-    assert np.any(ms_adv_x != input_np), 'Fast gradient method: generate value' \
+    assert np.any(ms_adv_x != inputs1), 'Fast gradient method: generate value' \
         ' must not be equal to original value.'
@@ -332,18 +339,17 @@ def test_batch_generate_multi_inputs():
     Batch generate with multi-inputs unit test.
     """
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    input_np = np.random.random([10, 3]).astype(np.float32)
-    anno_np = np.random.random([10, 3]).astype(np.float32)
-    label = np.random.randint(0, 3, [10])
-    label = np.eye(3)[label].astype(np.float32)
+    inputs1 = np.asarray([[0.1, 0.2, 0.7]]).astype(np.float32)
+    inputs2 = np.asarray([[0.4, 0.8, 0.5]]).astype(np.float32)
+    labels1 = np.expand_dims(np.eye(3)[1].astype(np.float32), axis=0)
+    labels2 = np.expand_dims(np.eye(3)[2].astype(np.float32), axis=0)
-    loss_fn = SoftmaxCrossEntropyWithLogits(sparse=False)
-    with_loss_cell = WithLossCell(Net2(), loss_fn)
+    with_loss_cell = WithLossCell(Net2(), LossNet())
     grad_with_loss_net = GradWrapWithLoss(with_loss_cell)
     attack = FastGradientMethod(grad_with_loss_net)
-    ms_adv_x = attack.generate(input_np, (anno_np, label))
+    ms_adv_x = attack.generate((inputs1, inputs2), (labels1, labels2))
-    assert np.any(ms_adv_x != input_np), 'Fast gradient method: generate value' \
+    assert np.any(ms_adv_x != inputs1), 'Fast gradient method: generate value' \
         ' must not be equal to original value.'