@@ -45,7 +45,7 @@ class GeneticAttack(Attack): | |||||
default: 'classification'. | default: 'classification'. | ||||
targeted (bool): If True, turns on the targeted attack. If False, | targeted (bool): If True, turns on the targeted attack. If False, | ||||
turns on untargeted attack. It should be noted that only untargeted attack | turns on untargeted attack. It should be noted that only untargeted attack | ||||
is supproted for model_type='detection', Default: True. | |||||
is supported for model_type='detection', Default: True. | |||||
reserve_ratio (Union[int, float]): The percentage of objects that can be detected after attacks, | reserve_ratio (Union[int, float]): The percentage of objects that can be detected after attacks, | ||||
specifically for model_type='detection'. Reserve_ratio should be in the range of (0, 1). Default: 0.3. | specifically for model_type='detection'. Reserve_ratio should be in the range of (0, 1). Default: 0.3. | ||||
pop_size (int): The number of particles, which should be greater than | pop_size (int): The number of particles, which should be greater than | ||||
@@ -69,7 +69,33 @@ class GeneticAttack(Attack): | |||||
c (Union[int, float]): Weight of perturbation loss. Default: 0.1. | c (Union[int, float]): Weight of perturbation loss. Default: 0.1. | ||||
Examples: | Examples: | ||||
>>> attack = GeneticAttack(model) | |||||
>>> import numpy as np | |||||
>>> import mindspore.ops.operations as M | |||||
>>> from mindspore import Tensor | |||||
>>> from mindspore.nn import Cell | |||||
>>> from mindarmour import BlackModel | |||||
>>> from mindarmour.adv_robustness.attacks import GeneticAttack | |||||
>>> | |||||
>>> class ModelToBeAttacked(BlackModel): | |||||
>>> def __init__(self, network): | |||||
>>> super(ModelToBeAttacked, self).__init__() | |||||
>>> self._network = network | |||||
>>> def predict(self, inputs): | |||||
>>> result = self._network(Tensor(inputs.astype(np.float32))) | |||||
>>> return result.asnumpy() | |||||
>>> | |||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._softmax = M.Softmax() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> out = self._softmax(inputs) | |||||
>>> return out | |||||
>>> | |||||
>>> net = Net() | |||||
>>> model = ModelToBeAttacked(net) | |||||
>>> attack = GeneticAttack(model, sparse=False) | |||||
""" | """ | ||||
def __init__(self, model, model_type='classification', targeted=True, reserve_ratio=0.3, sparse=True, | def __init__(self, model, model_type='classification', targeted=True, reserve_ratio=0.3, sparse=True, | ||||
pop_size=6, mutation_rate=0.005, per_bounds=0.15, max_steps=1000, step_size=0.20, temp=0.3, | pop_size=6, mutation_rate=0.005, per_bounds=0.15, max_steps=1000, step_size=0.20, temp=0.3, | ||||
@@ -135,18 +161,77 @@ class GeneticAttack(Attack): | |||||
np.random.random(cur_pop.shape) < prob) + cur_pop | np.random.random(cur_pop.shape) < prob) + cur_pop | ||||
return mutated_pop | return mutated_pop | ||||
def generate(self, inputs, labels): | |||||
def _compute_next_generation(self, cur_pop, fit_vals, x_ori): | |||||
""" | """ | ||||
Generate adversarial examples based on input data and targeted | |||||
labels (or ground_truth labels). | |||||
Compute pop for next generation | |||||
Args: | Args: | ||||
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray if | |||||
model_type='classification'. The format of inputs can be (input1, input2, ...) or only one array if | |||||
model_type='detection'. | |||||
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels should | |||||
be numpy.ndarray if model_type='classification'. The format of labels should be (gt_boxes, gt_labels) | |||||
if model_type='detection'. | |||||
cur_pop (numpy.ndarray): Samples before mutation. | |||||
fit_vals (numpy.ndarray): fitness values | |||||
x_ori (numpy.ndarray): original input x | |||||
Returns: | |||||
numpy.ndarray, pop after generation | |||||
Examples: | |||||
>>> cur_pop, elite = self._compute_next_generation(cur_pop, fit_vals, x_ori) | |||||
""" | |||||
best_fit = max(fit_vals) | |||||
if best_fit > self._best_fit: | |||||
self._best_fit = best_fit | |||||
self._plateau_times = 0 | |||||
else: | |||||
self._plateau_times += 1 | |||||
adap_threshold = (lambda z: 100 if z > -0.4 else 300)(best_fit) | |||||
if self._plateau_times > adap_threshold: | |||||
self._adap_times += 1 | |||||
self._plateau_times = 0 | |||||
if self._adaptive: | |||||
step_noise = max(self._step_size, 0.4*(0.9**self._adap_times)) | |||||
step_p = max(self._mutation_rate, 0.5*(0.9**self._adap_times)) | |||||
else: | |||||
step_noise = self._step_size | |||||
step_p = self._mutation_rate | |||||
step_temp = self._temp | |||||
elite = cur_pop[np.argmax(fit_vals)] | |||||
select_probs = softmax(fit_vals/step_temp) | |||||
select_args = np.arange(self._pop_size) | |||||
parents_arg = np.random.choice( | |||||
a=select_args, size=2*(self._pop_size - 1), | |||||
replace=True, p=select_probs) | |||||
parent1 = cur_pop[parents_arg[:self._pop_size - 1]] | |||||
parent2 = cur_pop[parents_arg[self._pop_size - 1:]] | |||||
parent1_probs = select_probs[parents_arg[:self._pop_size - 1]] | |||||
parent2_probs = select_probs[parents_arg[self._pop_size - 1:]] | |||||
parent2_probs = parent2_probs / (parent1_probs + parent2_probs) | |||||
# duplicate the probabilities to all features of each particle. | |||||
dims = len(x_ori.shape) | |||||
for _ in range(dims): | |||||
parent2_probs = parent2_probs[:, np.newaxis] | |||||
parent2_probs = np.tile(parent2_probs, ((1,) + x_ori.shape)) | |||||
cross_probs = (np.random.random(parent1.shape) > | |||||
parent2_probs).astype(np.int32) | |||||
children = parent1*cross_probs + parent2*(1 - cross_probs) | |||||
mutated_children = self._mutation( | |||||
children, step_noise=self._per_bounds*step_noise, | |||||
prob=step_p) | |||||
cur_pop = np.concatenate((mutated_children, elite[np.newaxis, :])) | |||||
return cur_pop, elite | |||||
def _generate_classification(self, inputs, labels): | |||||
""" | |||||
Generate adversarial examples based on input data and | |||||
targeted labels (or ground_truth labels) for classification model. | |||||
Args: | |||||
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray. | |||||
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. | |||||
The format of labels should be numpy.ndarray. | |||||
Returns: | Returns: | ||||
- numpy.ndarray, bool values for each attack result. | - numpy.ndarray, bool values for each attack result. | ||||
@@ -156,28 +241,27 @@ class GeneticAttack(Attack): | |||||
- numpy.ndarray, query times for each sample. | - numpy.ndarray, query times for each sample. | ||||
Examples: | Examples: | ||||
>>> advs = attack.generate([[0.2, 0.3, 0.4], | |||||
>>> [0.3, 0.3, 0.2]], | |||||
>>> [1, 2]) | |||||
>>> batch_size = 6 | |||||
>>> x_test = np.random.rand(batch_size, 10) | |||||
>>> y_test = np.random.randint(low=0, high=10, size=batch_size) | |||||
>>> y_test = np.eye(10)[y_test] | |||||
>>> y_test = y_test.astype(np.float32) | |||||
>>> _, adv_data, _ = attack._generate_classification(x_test, y_test) | |||||
""" | """ | ||||
if self._model_type == 'classification': | |||||
inputs, labels = check_pair_numpy_param('inputs', inputs, | |||||
'labels', labels) | |||||
if self._sparse: | |||||
if labels.size > 1: | |||||
label_squ = np.squeeze(labels) | |||||
else: | |||||
label_squ = labels | |||||
if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]: | |||||
msg = "The parameter 'sparse' of GeneticAttack is True, but the input labels is not sparse style " \ | |||||
"and got its shape as {}.".format(labels.shape) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
inputs, labels = check_pair_numpy_param('inputs', inputs, 'labels', labels) | |||||
if self._sparse: | |||||
if labels.size > 1: | |||||
label_squ = np.squeeze(labels) | |||||
else: | else: | ||||
labels = np.argmax(labels, axis=1) | |||||
images = inputs | |||||
elif self._model_type == 'detection': | |||||
images, auxiliary_inputs, gt_boxes, gt_labels = check_detection_inputs(inputs, labels) | |||||
label_squ = labels | |||||
if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]: | |||||
msg = "The parameter 'sparse' of GeneticAttack is True, but the input labels is not sparse style " \ | |||||
"and got its shape as {}.".format(labels.shape) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
else: | |||||
labels = np.argmax(labels, axis=1) | |||||
images = inputs | |||||
adv_list = [] | adv_list = [] | ||||
success_list = [] | success_list = [] | ||||
@@ -188,17 +272,7 @@ class GeneticAttack(Attack): | |||||
if not self._bounds: | if not self._bounds: | ||||
self._bounds = [np.min(x_ori), np.max(x_ori)] | self._bounds = [np.min(x_ori), np.max(x_ori)] | ||||
pixel_deep = self._bounds[1] - self._bounds[0] | pixel_deep = self._bounds[1] - self._bounds[0] | ||||
if self._model_type == 'classification': | |||||
label_i = labels[i] | |||||
elif self._model_type == 'detection': | |||||
auxiliary_input_i = tuple() | |||||
for item in auxiliary_inputs: | |||||
auxiliary_input_i += (np.expand_dims(item[i], axis=0),) | |||||
gt_boxes_i, gt_labels_i = np.expand_dims(gt_boxes[i], axis=0), np.expand_dims(gt_labels[i], axis=0) | |||||
inputs_i = (images[i],) + auxiliary_input_i | |||||
confi_ori, gt_object_num = self._detection_scores(inputs_i, gt_boxes_i, gt_labels_i, model=self._model) | |||||
LOGGER.info(TAG, 'The number of ground-truth objects is %s', gt_object_num[0]) | |||||
label_i = labels[i] | |||||
# generate particles | # generate particles | ||||
ori_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0) | ori_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0) | ||||
@@ -215,106 +289,148 @@ class GeneticAttack(Attack): | |||||
ori_copies + pixel_deep*self._per_bounds), | ori_copies + pixel_deep*self._per_bounds), | ||||
self._bounds[0], self._bounds[1]) | self._bounds[0], self._bounds[1]) | ||||
if self._model_type == 'classification': | |||||
pop_preds = self._model.predict(cur_pop) | |||||
query_times += cur_pop.shape[0] | |||||
all_preds = np.argmax(pop_preds, axis=1) | |||||
if self._targeted: | |||||
success_pop = np.equal(label_i, all_preds).astype(np.int32) | |||||
else: | |||||
success_pop = np.not_equal(label_i, all_preds).astype(np.int32) | |||||
is_success = max(success_pop) | |||||
best_idx = np.argmax(success_pop) | |||||
target_preds = pop_preds[:, label_i] | |||||
others_preds_sum = np.sum(pop_preds, axis=1) - target_preds | |||||
if self._targeted: | |||||
fit_vals = target_preds - others_preds_sum | |||||
else: | |||||
fit_vals = others_preds_sum - target_preds | |||||
elif self._model_type == 'detection': | |||||
confi_adv, correct_nums_adv = self._detection_scores( | |||||
(cur_pop,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model) | |||||
LOGGER.info(TAG, 'The number of correctly detected objects in adversarial image is %s', | |||||
np.min(correct_nums_adv)) | |||||
query_times += self._pop_size | |||||
fit_vals = abs( | |||||
confi_ori - confi_adv) - self._c / self._pop_size * np.linalg.norm( | |||||
(cur_pop - x_ori).reshape(cur_pop.shape[0], -1), axis=1) | |||||
if np.max(fit_vals) < 0: | |||||
self._c /= 2 | |||||
if np.max(fit_vals) < -2: | |||||
LOGGER.debug(TAG, | |||||
'best fitness value is %s, which is too small. We recommend that you decrease ' | |||||
'the value of the initialization parameter c.', np.max(fit_vals)) | |||||
if iters < 3 and np.max(fit_vals) > 100: | |||||
LOGGER.debug(TAG, | |||||
'best fitness value is %s, which is too large. We recommend that you increase ' | |||||
'the value of the initialization parameter c.', np.max(fit_vals)) | |||||
if np.min(correct_nums_adv) <= int(gt_object_num*self._reserve_ratio): | |||||
is_success = True | |||||
best_idx = np.argmin(correct_nums_adv) | |||||
pop_preds = self._model.predict(cur_pop) | |||||
query_times += cur_pop.shape[0] | |||||
all_preds = np.argmax(pop_preds, axis=1) | |||||
if self._targeted: | |||||
success_pop = np.equal(label_i, all_preds).astype(np.int32) | |||||
else: | |||||
success_pop = np.not_equal(label_i, all_preds).astype(np.int32) | |||||
is_success = max(success_pop) | |||||
best_idx = np.argmax(success_pop) | |||||
target_preds = pop_preds[:, label_i] | |||||
others_preds_sum = np.sum(pop_preds, axis=1) - target_preds | |||||
if self._targeted: | |||||
fit_vals = target_preds - others_preds_sum | |||||
else: | |||||
fit_vals = others_preds_sum - target_preds | |||||
if is_success: | if is_success: | ||||
LOGGER.debug(TAG, 'successfully find one adversarial sample ' | LOGGER.debug(TAG, 'successfully find one adversarial sample ' | ||||
'and start Reduction process.') | 'and start Reduction process.') | ||||
final_adv = cur_pop[best_idx] | final_adv = cur_pop[best_idx] | ||||
if self._model_type == 'classification': | |||||
final_adv, query_times = self._reduction(x_ori, query_times, label_i, final_adv, | |||||
model=self._model, targeted_attack=self._targeted) | |||||
final_adv, query_times = self._reduction(x_ori, query_times, label_i, final_adv, | |||||
model=self._model, targeted_attack=self._targeted) | |||||
break | break | ||||
best_fit = max(fit_vals) | |||||
cur_pop, elite = self._compute_next_generation(cur_pop, fit_vals, x_ori) | |||||
if best_fit > self._best_fit: | |||||
self._best_fit = best_fit | |||||
self._plateau_times = 0 | |||||
else: | |||||
self._plateau_times += 1 | |||||
adap_threshold = (lambda z: 100 if z > -0.4 else 300)(best_fit) | |||||
if self._plateau_times > adap_threshold: | |||||
self._adap_times += 1 | |||||
self._plateau_times = 0 | |||||
if self._adaptive: | |||||
step_noise = max(self._step_size, 0.4*(0.9**self._adap_times)) | |||||
step_p = max(self._mutation_rate, 0.5*(0.9**self._adap_times)) | |||||
else: | |||||
step_noise = self._step_size | |||||
step_p = self._mutation_rate | |||||
step_temp = self._temp | |||||
elite = cur_pop[np.argmax(fit_vals)] | |||||
select_probs = softmax(fit_vals/step_temp) | |||||
select_args = np.arange(self._pop_size) | |||||
parents_arg = np.random.choice( | |||||
a=select_args, size=2*(self._pop_size - 1), | |||||
replace=True, p=select_probs) | |||||
parent1 = cur_pop[parents_arg[:self._pop_size - 1]] | |||||
parent2 = cur_pop[parents_arg[self._pop_size - 1:]] | |||||
parent1_probs = select_probs[parents_arg[:self._pop_size - 1]] | |||||
parent2_probs = select_probs[parents_arg[self._pop_size - 1:]] | |||||
parent2_probs = parent2_probs / (parent1_probs + parent2_probs) | |||||
# duplicate the probabilities to all features of each particle. | |||||
dims = len(x_ori.shape) | |||||
for _ in range(dims): | |||||
parent2_probs = parent2_probs[:, np.newaxis] | |||||
parent2_probs = np.tile(parent2_probs, ((1,) + x_ori.shape)) | |||||
cross_probs = (np.random.random(parent1.shape) > | |||||
parent2_probs).astype(np.int32) | |||||
childs = parent1*cross_probs + parent2*(1 - cross_probs) | |||||
mutated_childs = self._mutation( | |||||
childs, step_noise=self._per_bounds*step_noise, | |||||
prob=step_p) | |||||
cur_pop = np.concatenate((mutated_childs, elite[np.newaxis, :])) | |||||
if not is_success: | |||||
LOGGER.debug(TAG, 'fail to find adversarial sample.') | |||||
final_adv = elite | |||||
adv_list.append(final_adv) | |||||
LOGGER.debug(TAG, | |||||
'iteration times is: %d and query times is: %d', | |||||
iters, | |||||
query_times) | |||||
success_list.append(is_success) | |||||
query_times_list.append(query_times) | |||||
del ori_copies, cur_pert, cur_pop | |||||
return np.asarray(success_list), \ | |||||
np.asarray(adv_list), \ | |||||
np.asarray(query_times_list) | |||||
def _generate_detection(self, inputs, labels): | |||||
""" | |||||
Generate adversarial examples based on input data and | |||||
targeted labels (or ground_truth labels) for detection model. | |||||
Args: | |||||
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be only one array. | |||||
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels should | |||||
be (gt_boxes, gt_labels). | |||||
Returns: | |||||
- numpy.ndarray, bool values for each attack result. | |||||
- numpy.ndarray, generated adversarial examples. | |||||
- numpy.ndarray, query times for each sample. | |||||
Examples: | |||||
>>> batch_size = 6 | |||||
>>> x_test = np.random.rand(batch_size, 10) | |||||
>>> y_test = np.random.randint(low=0, high=10, size=batch_size) | |||||
>>> y_test = np.eye(10)[y_test] | |||||
>>> y_test = y_test.astype(np.float32) | |||||
>>> _, adv_data, _ = attack._generate_detection(x_test, y_test) | |||||
""" | |||||
images, auxiliary_inputs, gt_boxes, gt_labels = check_detection_inputs(inputs, labels) | |||||
adv_list = [] | |||||
success_list = [] | |||||
query_times_list = [] | |||||
for i in range(images.shape[0]): | |||||
is_success = False | |||||
x_ori = images[i] | |||||
if not self._bounds: | |||||
self._bounds = [np.min(x_ori), np.max(x_ori)] | |||||
pixel_deep = self._bounds[1] - self._bounds[0] | |||||
auxiliary_input_i = tuple() | |||||
for item in auxiliary_inputs: | |||||
auxiliary_input_i += (np.expand_dims(item[i], axis=0),) | |||||
gt_boxes_i, gt_labels_i = np.expand_dims(gt_boxes[i], axis=0), np.expand_dims(gt_labels[i], axis=0) | |||||
inputs_i = (images[i],) + auxiliary_input_i | |||||
confi_ori, gt_object_num = self._detection_scores(inputs_i, gt_boxes_i, gt_labels_i, model=self._model) | |||||
LOGGER.info(TAG, 'The number of ground-truth objects is %s', gt_object_num[0]) | |||||
# generate particles | |||||
ori_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0) | |||||
# initial perturbations | |||||
cur_pert = np.random.uniform(self._bounds[0], self._bounds[1], ori_copies.shape) | |||||
cur_pop = ori_copies + cur_pert | |||||
query_times = 0 | |||||
iters = 0 | |||||
while iters < self._max_steps: | |||||
iters += 1 | |||||
cur_pop = np.clip(np.clip(cur_pop, | |||||
ori_copies - pixel_deep*self._per_bounds, | |||||
ori_copies + pixel_deep*self._per_bounds), | |||||
self._bounds[0], self._bounds[1]) | |||||
confi_adv, correct_nums_adv = self._detection_scores( | |||||
(cur_pop,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model) | |||||
LOGGER.info(TAG, 'The number of correctly detected objects in adversarial image is %s', | |||||
np.min(correct_nums_adv)) | |||||
query_times += self._pop_size | |||||
fit_vals = abs( | |||||
confi_ori - confi_adv) - self._c / self._pop_size * np.linalg.norm( | |||||
(cur_pop - x_ori).reshape(cur_pop.shape[0], -1), axis=1) | |||||
if np.max(fit_vals) < 0: | |||||
self._c /= 2 | |||||
if np.max(fit_vals) < -2: | |||||
LOGGER.debug(TAG, | |||||
'best fitness value is %s, which is too small. We recommend that you decrease ' | |||||
'the value of the initialization parameter c.', np.max(fit_vals)) | |||||
if iters < 3 and np.max(fit_vals) > 100: | |||||
LOGGER.debug(TAG, | |||||
'best fitness value is %s, which is too large. We recommend that you increase ' | |||||
'the value of the initialization parameter c.', np.max(fit_vals)) | |||||
if np.min(correct_nums_adv) <= int(gt_object_num*self._reserve_ratio): | |||||
is_success = True | |||||
best_idx = np.argmin(correct_nums_adv) | |||||
if is_success: | |||||
LOGGER.debug(TAG, 'successfully find one adversarial sample ' | |||||
'and start Reduction process.') | |||||
final_adv = cur_pop[best_idx] | |||||
break | |||||
cur_pop, elite = self._compute_next_generation(cur_pop, fit_vals, x_ori) | |||||
if not is_success: | if not is_success: | ||||
LOGGER.debug(TAG, 'fail to find adversarial sample.') | LOGGER.debug(TAG, 'fail to find adversarial sample.') | ||||
final_adv = elite | final_adv = elite | ||||
if self._model_type == 'detection': | |||||
final_adv, query_times = self._fast_reduction( | |||||
x_ori, final_adv, query_times, auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model) | |||||
final_adv, query_times = self._fast_reduction( | |||||
x_ori, final_adv, query_times, auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model) | |||||
adv_list.append(final_adv) | adv_list.append(final_adv) | ||||
LOGGER.debug(TAG, | LOGGER.debug(TAG, | ||||
@@ -327,3 +443,38 @@ class GeneticAttack(Attack): | |||||
return np.asarray(success_list), \ | return np.asarray(success_list), \ | ||||
np.asarray(adv_list), \ | np.asarray(adv_list), \ | ||||
np.asarray(query_times_list) | np.asarray(query_times_list) | ||||
def generate(self, inputs, labels): | |||||
""" | |||||
Generate adversarial examples based on input data and targeted labels (or ground_truth labels). | |||||
Args: | |||||
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray if | |||||
model_type='classification'. The format of inputs can be (input1, input2, ...) or only one array if | |||||
model_type='detection'. | |||||
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels should | |||||
be numpy.ndarray if model_type='classification'. The format of labels should be (gt_boxes, gt_labels) | |||||
if model_type='detection'. | |||||
Returns: | |||||
- numpy.ndarray, bool values for each attack result. | |||||
- numpy.ndarray, generated adversarial examples. | |||||
- numpy.ndarray, query times for each sample. | |||||
Examples: | |||||
>>> batch_size = 6 | |||||
>>> x_test = np.random.rand(batch_size, 10) | |||||
>>> y_test = np.random.randint(low=0, high=10, size=batch_size) | |||||
>>> y_test = np.eye(10)[y_test] | |||||
>>> y_test = y_test.astype(np.float32) | |||||
>>> _, adv_data, _ = attack.generate(x_test, y_test) | |||||
""" | |||||
if self._model_type == 'classification': | |||||
success_list, adv_data, query_time_list = self._generate_classification(inputs, labels) | |||||
elif self._model_type == 'detection': | |||||
success_list, adv_data, query_time_list = self._generate_detection(inputs, labels) | |||||
return success_list, adv_data, query_time_list |
@@ -75,11 +75,26 @@ class HopSkipJumpAttack(Attack): | |||||
ValueError: If constraint not in ['l2', 'linf'] | ValueError: If constraint not in ['l2', 'linf'] | ||||
Examples: | Examples: | ||||
>>> x_test = np.asarray(np.random.random((sample_num, | |||||
>>> sample_length)), np.float32) | |||||
>>> y_test = np.random.randint(0, class_num, size=sample_num) | |||||
>>> instance = HopSkipJumpAttack(user_model) | |||||
>>> adv_x = instance.generate(x_test, y_test) | |||||
>>> import numpy as np | |||||
>>> from mindspore import Tensor | |||||
>>> from mindarmour import BlackModel | |||||
>>> from mindarmour.adv_robustness.attacks import HopSkipJumpAttack | |||||
>>> from tests.ut.python.utils.mock_net import Net | |||||
>>> | |||||
>>> class ModelToBeAttacked(BlackModel): | |||||
>>> def __init__(self, network): | |||||
>>> super(ModelToBeAttacked, self).__init__() | |||||
>>> self._network = network | |||||
>>> def predict(self, inputs): | |||||
>>> if len(inputs.shape) == 3: | |||||
>>> inputs = inputs[np.newaxis, :] | |||||
>>> result = self._network(Tensor(inputs.astype(np.float32))) | |||||
>>> return result.asnumpy() | |||||
>>> | |||||
>>> | |||||
>>> net = Net() | |||||
>>> model = ModelToBeAttacked(net) | |||||
>>> attack = HopSkipJumpAttack(model) | |||||
""" | """ | ||||
def __init__(self, model, init_num_evals=100, max_num_evals=1000, | def __init__(self, model, init_num_evals=100, max_num_evals=1000, | ||||
@@ -173,7 +188,13 @@ class HopSkipJumpAttack(Attack): | |||||
- numpy.ndarray, query times for each sample. | - numpy.ndarray, query times for each sample. | ||||
Examples: | Examples: | ||||
>>> generate([[0.1,0.2,0.2],[0.2,0.3,0.4]],[2,6]) | |||||
>>> attack = HopSkipJumpAttack(model) | |||||
>>> n, c, h, w = 1, 1, 32, 32 | |||||
>>> class_num = 3 | |||||
>>> x_test = np.asarray(np.random.random((n,c,h,w)), np.float32) | |||||
>>> y_test = np.random.randint(0, class_num, size=n) | |||||
>>> | |||||
>>> _, adv_x, _= attack.generate(x_test, y_test) | |||||
""" | """ | ||||
if labels is not None: | if labels is not None: | ||||
inputs, labels = check_pair_numpy_param('inputs', inputs, | inputs, labels = check_pair_numpy_param('inputs', inputs, | ||||
@@ -79,16 +79,27 @@ class NES(Attack): | |||||
input labels are one-hot-encoded. Default: True. | input labels are one-hot-encoded. Default: True. | ||||
Examples: | Examples: | ||||
>>> SCENE = 'Label_Only' | |||||
>>> TOP_K = 5 | |||||
>>> num_class = 5 | |||||
>>> nes_instance = NES(user_model, SCENE, top_k=TOP_K) | |||||
>>> initial_img = np.asarray(np.random.random((32, 32)), np.float32) | |||||
>>> target_image = np.asarray(np.random.random((32, 32)), np.float32) | |||||
>>> orig_class = 0 | |||||
>>> target_class = 2 | |||||
>>> nes_instance.set_target_images(target_image) | |||||
>>> tag, adv, queries = nes_instance.generate([initial_img], [target_class]) | |||||
>>> import numpy as np | |||||
>>> from mindspore import Tensor | |||||
>>> from mindarmour import BlackModel | |||||
>>> from mindarmour.adv_robustness.attacks import NES | |||||
>>> from tests.ut.python.utils.mock_net import Net | |||||
>>> | |||||
>>> class ModelToBeAttacked(BlackModel): | |||||
>>> def __init__(self, network): | |||||
>>> super(ModelToBeAttacked, self).__init__() | |||||
>>> self._network = network | |||||
>>> def predict(self, inputs): | |||||
>>> if len(inputs.shape) == 3: | |||||
>>> inputs = inputs[np.newaxis, :] | |||||
>>> result = self._network(Tensor(inputs.astype(np.float32))) | |||||
>>> return result.asnumpy() | |||||
>>> | |||||
>>> net = Net() | |||||
>>> model = ModelToBeAttacked(net) | |||||
>>> SCENE = 'Query_Limit' | |||||
>>> TOP_K = -1 | |||||
>>> attack= NES(model, SCENE, top_k=TOP_K) | |||||
""" | """ | ||||
def __init__(self, model, scene, max_queries=10000, top_k=-1, num_class=10, batch_size=128, epsilon=0.3, | def __init__(self, model, scene, max_queries=10000, top_k=-1, num_class=10, batch_size=128, epsilon=0.3, | ||||
@@ -146,8 +157,19 @@ class NES(Attack): | |||||
ValueError: If scene is not in ['Label_Only', 'Partial_Info', 'Query_Limit'] | ValueError: If scene is not in ['Label_Only', 'Partial_Info', 'Query_Limit'] | ||||
Examples: | Examples: | ||||
>>> advs = attack.generate([[0.2, 0.3, 0.4], [0.3, 0.3, 0.2]], | |||||
>>> [1, 2]) | |||||
>>> net = Net() | |||||
>>> model = ModelToBeAttacked(net) | |||||
>>> SCENE = 'Query_Limit' | |||||
>>> TOP_K = -1 | |||||
>>> attack= NES(model, SCENE, top_k=TOP_K) | |||||
>>> | |||||
>>> num_class = 5 | |||||
>>> x_test = np.asarray(np.random.random((32, 32)), np.float32) | |||||
>>> target_image = np.asarray(np.random.random((32, 32)), np.float32) | |||||
>>> orig_class = 0 | |||||
>>> target_class = 2 | |||||
>>> attack.set_target_images(target_image) | |||||
>>> tag, adv, queries = attack.generate(np.array(x_test), np.array([target_class])) | |||||
""" | """ | ||||
inputs, labels = check_pair_numpy_param('inputs', inputs, 'labels', labels) | inputs, labels = check_pair_numpy_param('inputs', inputs, 'labels', labels) | ||||
if not self._sparse: | if not self._sparse: | ||||
@@ -47,6 +47,22 @@ class PointWiseAttack(Attack): | |||||
Default: True. | Default: True. | ||||
Examples: | Examples: | ||||
>>> import numpy as np | |||||
>>> from mindspore import Tensor | |||||
>>> from mindarmour import BlackModel | |||||
>>> from mindarmour.adv_robustness.attacks import PointWiseAttack | |||||
>>> from tests.ut.python.utils.mock_net import Net | |||||
>>> | |||||
>>> class ModelToBeAttacked(BlackModel): | |||||
>>> def __init__(self, network): | |||||
>>> super(ModelToBeAttacked, self).__init__() | |||||
>>> self._network = network | |||||
>>> def predict(self, inputs): | |||||
>>> result = self._network(Tensor(inputs.astype(np.float32))) | |||||
>>> return result.asnumpy() | |||||
>>> | |||||
>>> net = Net() | |||||
>>> model = ModelToBeAttacked(net) | |||||
>>> attack = PointWiseAttack(model) | >>> attack = PointWiseAttack(model) | ||||
""" | """ | ||||
@@ -79,7 +95,12 @@ class PointWiseAttack(Attack): | |||||
- numpy.ndarray, query times for each sample. | - numpy.ndarray, query times for each sample. | ||||
Examples: | Examples: | ||||
>>> is_adv_list, adv_list, query_times_each_adv = attack.generate([[0.1, 0.2, 0.6], [0.3, 0, 0.4]], [2, 3]) | |||||
>>> net = Net() | |||||
>>> model = ModelToBeAttacked(net) | |||||
>>> attack = PointWiseAttack(model) | |||||
>>> x_test = np.asarray(np.random.random((1,1,32,32)), np.float32) | |||||
>>> y_test = np.random.randint(0, 3, size=1) | |||||
>>> is_adv_list, adv_list, query_times_each_adv = attack.generate(x_test, y_test) | |||||
""" | """ | ||||
arr_x, arr_y = check_pair_numpy_param('inputs', inputs, 'labels', labels) | arr_x, arr_y = check_pair_numpy_param('inputs', inputs, 'labels', labels) | ||||
if not self._sparse: | if not self._sparse: | ||||
@@ -55,7 +55,7 @@ class PSOAttack(Attack): | |||||
clip_max). Default: None. | clip_max). Default: None. | ||||
targeted (bool): If True, turns on the targeted attack. If False, | targeted (bool): If True, turns on the targeted attack. If False, | ||||
turns on untargeted attack. It should be noted that only untargeted attack | turns on untargeted attack. It should be noted that only untargeted attack | ||||
is supproted for model_type='detection', Default: False. | |||||
is supported for model_type='detection', Default: False. | |||||
sparse (bool): If True, input labels are sparse-encoded. If False, | sparse (bool): If True, input labels are sparse-encoded. If False, | ||||
input labels are one-hot-encoded. Default: True. | input labels are one-hot-encoded. Default: True. | ||||
model_type (str): The type of targeted model. 'classification' and 'detection' are supported now. | model_type (str): The type of targeted model. 'classification' and 'detection' are supported now. | ||||
@@ -64,7 +64,35 @@ class PSOAttack(Attack): | |||||
specifically for model_type='detection'. Reserve_ratio should be in the range of (0, 1). Default: 0.3. | specifically for model_type='detection'. Reserve_ratio should be in the range of (0, 1). Default: 0.3. | ||||
Examples: | Examples: | ||||
>>> attack = PSOAttack(model) | |||||
>>> import numpy as np | |||||
>>> import mindspore.nn as nn | |||||
>>> from mindspore import Tensor | |||||
>>> from mindspore.nn import Cell | |||||
>>> from mindarmour import BlackModel | |||||
>>> from mindarmour.adv_robustness.attacks import PSOAttack | |||||
>>> | |||||
>>> class ModelToBeAttacked(BlackModel): | |||||
>>> def __init__(self, network): | |||||
>>> super(ModelToBeAttacked, self).__init__() | |||||
>>> self._network = network | |||||
>>> def predict(self, inputs): | |||||
>>> if len(inputs.shape) == 1: | |||||
>>> inputs = np.expand_dims(inputs, axis=0) | |||||
>>> result = self._network(Tensor(inputs.astype(np.float32))) | |||||
>>> return result.asnumpy() | |||||
>>> | |||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._relu = nn.ReLU() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> out = self._relu(inputs) | |||||
>>> return out | |||||
>>> | |||||
>>> net = Net() | |||||
>>> model = ModelToBeAttacked(net) | |||||
>>> attack = PSOAttack(model, bounds=(0.0, 1.0), pm=0.5, sparse=False) | |||||
""" | """ | ||||
def __init__(self, model, model_type='classification', targeted=False, reserve_ratio=0.3, sparse=True, | def __init__(self, model, model_type='classification', targeted=False, reserve_ratio=0.3, sparse=True, | ||||
@@ -169,18 +197,33 @@ class PSOAttack(Attack): | |||||
self._bounds[1]) | self._bounds[1]) | ||||
return mutated_pop | return mutated_pop | ||||
def generate(self, inputs, labels): | |||||
def _check_best_fitness(self, best_fitness, iters): | |||||
if best_fitness < -2: | |||||
LOGGER.debug(TAG, 'best fitness value is %s, which is too small. We recommend that you decrease ' | |||||
'the value of the initialization parameter c.', best_fitness) | |||||
if iters < 3 and best_fitness > 100: | |||||
LOGGER.debug(TAG, 'best fitness value is %s, which is too large. We recommend that you increase ' | |||||
'the value of the initialization parameter c.', best_fitness) | |||||
def _update_best_fit_position(self, fit_value, par_best_fit, par_best_poi, par, best_fitness, best_position): | |||||
for k in range(self._pop_size): | |||||
if fit_value[k] > par_best_fit[k]: | |||||
par_best_fit[k] = fit_value[k] | |||||
par_best_poi[k] = par[k] | |||||
if fit_value[k] > best_fitness: | |||||
best_fitness = fit_value[k] | |||||
best_position = par[k].copy() | |||||
return par_best_fit, par_best_poi, best_fitness, best_position | |||||
def _generate_classification(self, inputs, labels): | |||||
""" | """ | ||||
Generate adversarial examples based on input data and targeted | |||||
labels (or ground_truth labels). | |||||
Generate adversarial examples based on input data and | |||||
targeted labels (or ground_truth labels) for classification model. | |||||
Args: | Args: | ||||
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray if | |||||
model_type='classification'. The format of inputs can be (input1, input2, ...) or only one array if | |||||
model_type='detection'. | |||||
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray. | |||||
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels should | labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels should | ||||
be numpy.ndarray if model_type='classification'. The format of labels should be (gt_boxes, gt_labels) | |||||
if model_type='detection'. | |||||
be numpy.ndarray. | |||||
Returns: | Returns: | ||||
- numpy.ndarray, bool values for each attack result. | - numpy.ndarray, bool values for each attack result. | ||||
@@ -190,28 +233,32 @@ class PSOAttack(Attack): | |||||
- numpy.ndarray, query times for each sample. | - numpy.ndarray, query times for each sample. | ||||
Examples: | Examples: | ||||
>>> advs = attack.generate([[0.2, 0.3, 0.4], [0.3, 0.3, 0.2]], | |||||
>>> [1, 2]) | |||||
>>> net = Net() | |||||
>>> model = ModelToBeAttacked(net) | |||||
>>> attack = PSOAttack(model, bounds=(0.0, 1.0), pm=0.5, sparse=False) | |||||
>>> batch_size = 6 | |||||
>>> x_test = np.random.rand(batch_size, 10) | |||||
>>> y_test = np.random.randint(low=0, high=10, size=batch_size) | |||||
>>> y_test = np.eye(10)[y_test] | |||||
>>> y_test = y_test.astype(np.float32) | |||||
>>> _, adv_data, _ = attack.generate(x_test, y_test) | |||||
""" | """ | ||||
# inputs check | # inputs check | ||||
if self._model_type == 'classification': | |||||
inputs, labels = check_pair_numpy_param('inputs', inputs, | |||||
'labels', labels) | |||||
if self._sparse: | |||||
if labels.size > 1: | |||||
label_squ = np.squeeze(labels) | |||||
else: | |||||
label_squ = labels | |||||
if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]: | |||||
msg = "The parameter 'sparse' of PSOAttack is True, but the input labels is not sparse style and " \ | |||||
"got its shape as {}.".format(labels.shape) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
inputs, labels = check_pair_numpy_param('inputs', inputs, | |||||
'labels', labels) | |||||
if self._sparse: | |||||
if labels.size > 1: | |||||
label_squ = np.squeeze(labels) | |||||
else: | else: | ||||
labels = np.argmax(labels, axis=1) | |||||
images = inputs | |||||
elif self._model_type == 'detection': | |||||
images, auxiliary_inputs, gt_boxes, gt_labels = check_detection_inputs(inputs, labels) | |||||
label_squ = labels | |||||
if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]: | |||||
msg = "The parameter 'sparse' of PSOAttack is True, but the input labels is not sparse style and " \ | |||||
"got its shape as {}.".format(labels.shape) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
else: | |||||
labels = np.argmax(labels, axis=1) | |||||
images = inputs | |||||
# generate one adversarial each time | # generate one adversarial each time | ||||
adv_list = [] | adv_list = [] | ||||
@@ -226,17 +273,9 @@ class PSOAttack(Attack): | |||||
pixel_deep = self._bounds[1] - self._bounds[0] | pixel_deep = self._bounds[1] - self._bounds[0] | ||||
q_times += 1 | q_times += 1 | ||||
if self._model_type == 'classification': | |||||
label_i = labels[i] | |||||
confi_ori = self._confidence_cla(x_ori, label_i) | |||||
elif self._model_type == 'detection': | |||||
auxiliary_input_i = tuple() | |||||
for item in auxiliary_inputs: | |||||
auxiliary_input_i += (np.expand_dims(item[i], axis=0),) | |||||
gt_boxes_i, gt_labels_i = np.expand_dims(gt_boxes[i], axis=0), np.expand_dims(gt_labels[i], axis=0) | |||||
inputs_i = (images[i],) + auxiliary_input_i | |||||
confi_ori, gt_object_num = self._detection_scores(inputs_i, gt_boxes_i, gt_labels_i, self._model) | |||||
LOGGER.info(TAG, 'The number of ground-truth objects is %s', gt_object_num[0]) | |||||
label_i = labels[i] | |||||
confi_ori = self._confidence_cla(x_ori, label_i) | |||||
# step1, initializing | # step1, initializing | ||||
# initial global optimum fitness value, cannot set to be -inf | # initial global optimum fitness value, cannot set to be -inf | ||||
@@ -277,57 +316,178 @@ class PSOAttack(Attack): | |||||
x_copies + (np.abs(x_copies) + 0.1*pixel_deep)*self._per_bounds), | x_copies + (np.abs(x_copies) + 0.1*pixel_deep)*self._per_bounds), | ||||
self._bounds[0], self._bounds[1]) | self._bounds[0], self._bounds[1]) | ||||
if self._model_type == 'classification': | |||||
confi_adv = self._confidence_cla(par, label_i) | |||||
elif self._model_type == 'detection': | |||||
confi_adv, _ = self._detection_scores( | |||||
(par,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model) | |||||
confi_adv = self._confidence_cla(par, label_i) | |||||
q_times += self._pop_size | q_times += self._pop_size | ||||
fit_value = self._fitness(confi_ori, confi_adv, x_ori, par) | fit_value = self._fitness(confi_ori, confi_adv, x_ori, par) | ||||
for k in range(self._pop_size): | |||||
if fit_value[k] > par_best_fit[k]: | |||||
par_best_fit[k] = fit_value[k] | |||||
par_best_poi[k] = par[k] | |||||
if fit_value[k] > best_fitness: | |||||
best_fitness = fit_value[k] | |||||
best_position = par[k].copy() | |||||
par_best_fit, par_best_poi, best_fitness, best_position = self._update_best_fit_position(fit_value, | |||||
par_best_fit, | |||||
par_best_poi, | |||||
par, | |||||
best_fitness, | |||||
best_position) | |||||
iters += 1 | iters += 1 | ||||
if best_fitness < -2: | |||||
LOGGER.debug(TAG, 'best fitness value is %s, which is too small. We recommend that you decrease ' | |||||
'the value of the initialization parameter c.', best_fitness) | |||||
if iters < 3 and best_fitness > 100: | |||||
LOGGER.debug(TAG, 'best fitness value is %s, which is too large. We recommend that you increase ' | |||||
'the value of the initialization parameter c.', best_fitness) | |||||
self._check_best_fitness(best_fitness, iters) | |||||
is_mutation = False | is_mutation = False | ||||
if (best_fitness - last_best_fit) < last_best_fit*0.05: | if (best_fitness - last_best_fit) < last_best_fit*0.05: | ||||
is_mutation = True | is_mutation = True | ||||
q_times += 1 | q_times += 1 | ||||
if self._model_type == 'classification': | |||||
cur_pre = self._model.predict(best_position) | |||||
cur_label = np.argmax(cur_pre) | |||||
if (self._targeted and cur_label == label_i) or (not self._targeted and cur_label != label_i): | |||||
is_success = True | |||||
elif self._model_type == 'detection': | |||||
_, correct_nums_adv = self._detection_scores( | |||||
(best_position,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model) | |||||
LOGGER.info(TAG, 'The number of correctly detected objects in adversarial image is %s', | |||||
correct_nums_adv[0]) | |||||
if correct_nums_adv <= int(gt_object_num*self._reserve_ratio): | |||||
is_success = True | |||||
cur_pre = self._model.predict(best_position) | |||||
cur_label = np.argmax(cur_pre) | |||||
if (self._targeted and cur_label == label_i) or (not self._targeted and cur_label != label_i): | |||||
is_success = True | |||||
if is_success: | if is_success: | ||||
LOGGER.debug(TAG, 'successfully find one adversarial ' | LOGGER.debug(TAG, 'successfully find one adversarial ' | ||||
'sample and start Reduction process') | 'sample and start Reduction process') | ||||
# step3, reduction | # step3, reduction | ||||
if self._model_type == 'classification': | |||||
best_position, q_times = self._reduction(x_ori, q_times, label_i, best_position, self._model, | |||||
targeted_attack=self._targeted) | |||||
best_position, q_times = self._reduction(x_ori, q_times, label_i, best_position, self._model, | |||||
targeted_attack=self._targeted) | |||||
break | |||||
if not is_success: | |||||
LOGGER.debug(TAG, | |||||
'fail to find adversarial sample, iteration ' | |||||
'times is: %d and query times is: %d', | |||||
iters, | |||||
q_times) | |||||
adv_list.append(best_position) | |||||
success_list.append(is_success) | |||||
query_times_list.append(q_times) | |||||
del x_copies, cur_noise, par, par_best_poi | |||||
return np.asarray(success_list), \ | |||||
np.asarray(adv_list), \ | |||||
np.asarray(query_times_list) | |||||
def _generate_detection(self, inputs, labels): | |||||
""" | |||||
Generate adversarial examples based on input data and | |||||
targeted labels (or ground_truth labels) for detection model. | |||||
Args: | |||||
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs can be (input1, input2, ...) | |||||
or only one array. | |||||
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. | |||||
The format of labels should be (gt_boxes, gt_labels). | |||||
Returns: | |||||
- numpy.ndarray, bool values for each attack result. | |||||
- numpy.ndarray, generated adversarial examples. | |||||
- numpy.ndarray, query times for each sample. | |||||
Examples: | |||||
>>> net = Net() | |||||
>>> model = ModelToBeAttacked(net) | |||||
>>> attack = PSOAttack(model, bounds=(0.0, 1.0), pm=0.5, sparse=False) | |||||
>>> batch_size = 6 | |||||
>>> x_test = np.random.rand(batch_size, 10) | |||||
>>> y_test = np.random.randint(low=0, high=10, size=batch_size) | |||||
>>> y_test = np.eye(10)[y_test] | |||||
>>> y_test = y_test.astype(np.float32) | |||||
>>> _, adv_data, _ = attack.generate(x_test, y_test) | |||||
""" | |||||
# inputs check | |||||
images, auxiliary_inputs, gt_boxes, gt_labels = check_detection_inputs(inputs, labels) | |||||
# generate one adversarial each time | |||||
adv_list = [] | |||||
success_list = [] | |||||
query_times_list = [] | |||||
for i in range(images.shape[0]): | |||||
is_success = False | |||||
q_times = 0 | |||||
x_ori = images[i] | |||||
if not self._bounds: | |||||
self._bounds = [np.min(x_ori), np.max(x_ori)] | |||||
pixel_deep = self._bounds[1] - self._bounds[0] | |||||
q_times += 1 | |||||
auxiliary_input_i = tuple() | |||||
for item in auxiliary_inputs: | |||||
auxiliary_input_i += (np.expand_dims(item[i], axis=0),) | |||||
gt_boxes_i, gt_labels_i = np.expand_dims(gt_boxes[i], axis=0), np.expand_dims(gt_labels[i], axis=0) | |||||
inputs_i = (images[i],) + auxiliary_input_i | |||||
confi_ori, gt_object_num = self._detection_scores(inputs_i, gt_boxes_i, gt_labels_i, self._model) | |||||
LOGGER.info(TAG, 'The number of ground-truth objects is %s', gt_object_num[0]) | |||||
# step1, initializing | |||||
# initial global optimum fitness value, cannot set to be -inf | |||||
best_fitness = -np.inf | |||||
# initial global optimum position | |||||
best_position = x_ori | |||||
x_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0) | |||||
cur_noise = np.clip(np.random.random(x_copies.shape)*pixel_deep, | |||||
(0 - self._per_bounds)*(np.abs(x_copies) + 0.1), | |||||
self._per_bounds*(np.abs(x_copies) + 0.1)) | |||||
# initial advs | |||||
par = np.clip(x_copies + cur_noise, self._bounds[0], self._bounds[1]) | |||||
# initial optimum positions for particles | |||||
par_best_poi = np.copy(par) | |||||
# initial optimum fitness values | |||||
par_best_fit = -np.inf*np.ones(self._pop_size) | |||||
# step2, optimization | |||||
# initial velocities for particles | |||||
v_particles = np.zeros(par.shape) | |||||
is_mutation = False | |||||
iters = 0 | |||||
while iters < self._t_max: | |||||
last_best_fit = best_fitness | |||||
ran_1 = np.random.random(par.shape) | |||||
ran_2 = np.random.random(par.shape) | |||||
v_particles = self._step_size*( | |||||
v_particles + self._c1*ran_1*(best_position - par)) \ | |||||
+ self._c2*ran_2*(par_best_poi - par) | |||||
par += v_particles | |||||
if iters > 6 and is_mutation: | |||||
par = self._mutation_op(par) | |||||
par = np.clip(np.clip(par, | |||||
x_copies - (np.abs(x_copies) + 0.1*pixel_deep)*self._per_bounds, | |||||
x_copies + (np.abs(x_copies) + 0.1*pixel_deep)*self._per_bounds), | |||||
self._bounds[0], self._bounds[1]) | |||||
confi_adv, _ = self._detection_scores( | |||||
(par,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model) | |||||
q_times += self._pop_size | |||||
fit_value = self._fitness(confi_ori, confi_adv, x_ori, par) | |||||
par_best_fit, par_best_poi, best_fitness, best_position = self._update_best_fit_position(fit_value, | |||||
par_best_fit, | |||||
par_best_poi, | |||||
par, | |||||
best_fitness, | |||||
best_position) | |||||
iters += 1 | |||||
self._check_best_fitness(best_fitness, iters) | |||||
is_mutation = False | |||||
if (best_fitness - last_best_fit) < last_best_fit*0.05: | |||||
is_mutation = True | |||||
q_times += 1 | |||||
_, correct_nums_adv = self._detection_scores( | |||||
(best_position,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model) | |||||
LOGGER.info(TAG, 'The number of correctly detected objects in adversarial image is %s', | |||||
correct_nums_adv[0]) | |||||
if correct_nums_adv <= int(gt_object_num*self._reserve_ratio): | |||||
is_success = True | |||||
if is_success: | |||||
LOGGER.debug(TAG, 'successfully find one adversarial ' | |||||
'sample and start Reduction process') | |||||
break | break | ||||
if self._model_type == 'detection': | |||||
best_position, q_times = self._fast_reduction(x_ori, best_position, q_times, | |||||
auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model) | |||||
best_position, q_times = self._fast_reduction(x_ori, best_position, q_times, | |||||
auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model) | |||||
if not is_success: | if not is_success: | ||||
LOGGER.debug(TAG, | LOGGER.debug(TAG, | ||||
'fail to find adversarial sample, iteration ' | 'fail to find adversarial sample, iteration ' | ||||
@@ -341,3 +501,43 @@ class PSOAttack(Attack): | |||||
return np.asarray(success_list), \ | return np.asarray(success_list), \ | ||||
np.asarray(adv_list), \ | np.asarray(adv_list), \ | ||||
np.asarray(query_times_list) | np.asarray(query_times_list) | ||||
def generate(self, inputs, labels): | |||||
""" | |||||
Generate adversarial examples based on input data and | |||||
targeted labels (or ground_truth labels). | |||||
Args: | |||||
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray if | |||||
model_type='classification'. The format of inputs can be (input1, input2, ...) or only one array if | |||||
model_type='detection'. | |||||
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels should | |||||
be numpy.ndarray if model_type='classification'. The format of labels should be (gt_boxes, gt_labels) | |||||
if model_type='detection'. | |||||
Returns: | |||||
- numpy.ndarray, bool values for each attack result. | |||||
- numpy.ndarray, generated adversarial examples. | |||||
- numpy.ndarray, query times for each sample. | |||||
Examples: | |||||
>>> net = Net() | |||||
>>> model = ModelToBeAttacked(net) | |||||
>>> attack = PSOAttack(model, bounds=(0.0, 1.0), pm=0.5, sparse=False) | |||||
>>> batch_size = 6 | |||||
>>> x_test = np.random.rand(batch_size, 10) | |||||
>>> y_test = np.random.randint(low=0, high=10, size=batch_size) | |||||
>>> y_test = np.eye(10)[y_test] | |||||
>>> y_test = y_test.astype(np.float32) | |||||
>>> _, adv_data, _ = attack.generate(x_test, y_test) | |||||
""" | |||||
# inputs check | |||||
if self._model_type == 'classification': | |||||
success_list, adv_data, query_time_list = self._generate_classification(inputs, labels) | |||||
elif self._model_type == 'detection': | |||||
success_list, adv_data, query_time_list = self._generate_detection(inputs, labels) | |||||
return success_list, adv_data, query_time_list |
@@ -40,6 +40,22 @@ class SaltAndPepperNoiseAttack(Attack): | |||||
Default: True. | Default: True. | ||||
Examples: | Examples: | ||||
>>> import numpy as np | |||||
>>> from mindspore import Tensor | |||||
>>> from mindarmour import BlackModel | |||||
>>> from mindarmour.adv_robustness.attacks import SaltAndPepperNoiseAttack | |||||
>>> from tests.ut.python.utils.mock_net import Net | |||||
>>> | |||||
>>> class ModelToBeAttacked(BlackModel): | |||||
>>> def __init__(self, network): | |||||
>>> super(ModelToBeAttacked, self).__init__() | |||||
>>> self._network = network | |||||
>>> def predict(self, inputs): | |||||
>>> result = self._network(Tensor(inputs.astype(np.float32))) | |||||
>>> return result.asnumpy() | |||||
>>> | |||||
>>> net = Net() | |||||
>>> model = ModelToBeAttacked(net) | |||||
>>> attack = SaltAndPepperNoiseAttack(model) | >>> attack = SaltAndPepperNoiseAttack(model) | ||||
""" | """ | ||||
@@ -69,7 +85,12 @@ class SaltAndPepperNoiseAttack(Attack): | |||||
- numpy.ndarray, query times for each sample. | - numpy.ndarray, query times for each sample. | ||||
Examples: | Examples: | ||||
>>> adv_list = attack.generate(([[0.1, 0.2, 0.6], [0.3, 0, 0.4]], [1, 2]) | |||||
>>> net = Net() | |||||
>>> model = ModelToBeAttacked(net) | |||||
>>> attack = PointWiseAttack(model) | |||||
>>> x_test = np.asarray(np.random.random((1,1,32,32)), np.float32) | |||||
>>> y_test = np.random.randint(0, 3, size=1) | |||||
>>> _, adv_list, _ = attack.generate(x_test, y_test) | |||||
""" | """ | ||||
arr_x, arr_y = check_pair_numpy_param('inputs', inputs, 'labels', labels) | arr_x, arr_y = check_pair_numpy_param('inputs', inputs, 'labels', labels) | ||||
if not self._sparse: | if not self._sparse: | ||||
@@ -95,7 +95,24 @@ class CarliniWagnerL2Attack(Attack): | |||||
input labels are onehot-coded. Default: True. | input labels are onehot-coded. Default: True. | ||||
Examples: | Examples: | ||||
>>> attack = CarliniWagnerL2Attack(network) | |||||
>>> import numpy as np | |||||
>>> import mindspore.ops.operations as M | |||||
>>> from mindspore.nn import Cell | |||||
>>> from mindarmour.adv_robustness.attacks import CarliniWagnerL2Attack | |||||
>>> | |||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._softmax = M.Softmax() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> out = self._softmax(inputs) | |||||
>>> return out | |||||
>>> | |||||
>>> input_np = np.array([[0.1, 0.2, 0.7, 0.5, 0.4]]).astype(np.float32) | |||||
>>> label_np = np.array([3]).astype(np.int64) | |||||
>>> num_classes = input_np.shape[1] | |||||
>>> attack = CarliniWagnerL2Attack(net, num_classes, targeted=False) | |||||
""" | """ | ||||
def __init__(self, network, num_classes, box_min=0.0, box_max=1.0, | def __init__(self, network, num_classes, box_min=0.0, box_max=1.0, | ||||
@@ -246,6 +263,14 @@ class CarliniWagnerL2Attack(Attack): | |||||
the_grad = the_grad*diff | the_grad = the_grad*diff | ||||
return inputs, the_grad | return inputs, the_grad | ||||
def _check_success(self, logits, labels): | |||||
""" check if attack success (include all examples)""" | |||||
if self._targeted: | |||||
is_adv = (np.argmax(logits, axis=1) == labels) | |||||
else: | |||||
is_adv = (np.argmax(logits, axis=1) != labels) | |||||
return is_adv | |||||
def generate(self, inputs, labels): | def generate(self, inputs, labels): | ||||
""" | """ | ||||
Generate adversarial examples based on input data and targeted labels. | Generate adversarial examples based on input data and targeted labels. | ||||
@@ -259,7 +284,30 @@ class CarliniWagnerL2Attack(Attack): | |||||
numpy.ndarray, generated adversarial examples. | numpy.ndarray, generated adversarial examples. | ||||
Examples: | Examples: | ||||
>>> advs = attack.generate([[0.1, 0.2, 0.6], [0.3, 0, 0.4]], [1, 2]] | |||||
>>> import numpy as np | |||||
>>> import mindspore.ops.operations as M | |||||
>>> from mindspore.nn import Cell | |||||
>>> from mindarmour.adv_robustness.attacks import CarliniWagnerL2Attack | |||||
>>> | |||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._softmax = M.Softmax() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> out = self._softmax(inputs) | |||||
>>> return out | |||||
>>> | |||||
>>> input_np = np.array([[0.1, 0.2, 0.7, 0.5, 0.4]]).astype(np.float32) | |||||
>>> num_classes = input_np.shape[1] | |||||
>>> | |||||
>>> label_np = np.array([3]).astype(np.int64) | |||||
>>> attack_nonTargeted = CarliniWagnerL2Attack(net, num_classes, targeted=False) | |||||
>>> advs_nonTargeted = attack_nonTargeted.generate(input_np, label_np) | |||||
>>> | |||||
>>> target_np = np.array([1]).astype(np.int64) | |||||
>>> attack_targeted = CarliniWagnerL2Attack(net, num_classes, targeted=False) | |||||
>>> advs_targeted = attack_targeted.generate(input_np, target_np) | |||||
""" | """ | ||||
LOGGER.debug(TAG, "enter the func generate.") | LOGGER.debug(TAG, "enter the func generate.") | ||||
@@ -302,11 +350,7 @@ class CarliniWagnerL2Attack(Attack): | |||||
logits, x_input, reconstructed_original, | logits, x_input, reconstructed_original, | ||||
labels, const, self._confidence) | labels, const, self._confidence) | ||||
# check if attack success (include all examples) | |||||
if self._targeted: | |||||
is_adv = (np.argmax(logits, axis=1) == labels) | |||||
else: | |||||
is_adv = (np.argmax(logits, axis=1) != labels) | |||||
is_adv = self._check_success(logits, labels) | |||||
for i in range(samples_num): | for i in range(samples_num): | ||||
if is_adv[i]: | if is_adv[i]: | ||||
@@ -117,7 +117,23 @@ class DeepFool(Attack): | |||||
input labels are onehot-coded. Default: True. | input labels are onehot-coded. Default: True. | ||||
Examples: | Examples: | ||||
>>> attack = DeepFool(network) | |||||
>>> import numpy as np | |||||
>>> import mindspore.ops.operations as P | |||||
>>> from mindspore.nn import Cell | |||||
>>> from mindspore import Tensor | |||||
>>> from mindarmour.adv_robustness.attacks import DeepFool | |||||
>>> | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._softmax = P.Softmax() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> out = self._softmax(inputs) | |||||
>>> return out | |||||
>>> | |||||
>>> net = Net() | |||||
>>> attack = DeepFool(net, classes, max_iters=10, norm_level=2, | |||||
bounds=(0.0, 1.0)) | |||||
""" | """ | ||||
def __init__(self, network, num_classes, model_type='classification', | def __init__(self, network, num_classes, model_type='classification', | ||||
@@ -165,14 +181,30 @@ class DeepFool(Attack): | |||||
NotImplementedError: If norm_level is not in [2, np.inf, '2', 'inf']. | NotImplementedError: If norm_level is not in [2, np.inf, '2', 'inf']. | ||||
Examples: | Examples: | ||||
>>> advs = generate([[0.2, 0.3, 0.4], [0.3, 0.4, 0.5]], [1, 2]) | |||||
>>> input_shape = (1, 5) | |||||
>>> _, classes = input_shape | |||||
>>> input_np = np.array([[0.1, 0.2, 0.7, 0.5, 0.4]]).astype(np.float32) | |||||
>>> input_me = Tensor(input_np) | |||||
>>> true_labels = np.argmax(net(input_me).asnumpy(), axis=1) | |||||
>>> attack = DeepFool(net, classes, max_iters=10, norm_level=2, bounds=(0.0, 1.0)) | |||||
>>> advs = attack.generate(input_np, true_labels) | |||||
""" | """ | ||||
if self._model_type == 'detection': | if self._model_type == 'detection': | ||||
return self._generate_detection(inputs, labels) | return self._generate_detection(inputs, labels) | ||||
if self._model_type == 'classification': | if self._model_type == 'classification': | ||||
return self._generate_classification(inputs, labels) | return self._generate_classification(inputs, labels) | ||||
return None | return None | ||||
def _update_image(self, x_origin, r_tot): | |||||
"""update image based on bounds""" | |||||
if self._bounds is not None: | |||||
clip_min, clip_max = self._bounds | |||||
images = x_origin + (1 + self._overshoot)*r_tot*(clip_max-clip_min) | |||||
images = np.clip(images, clip_min, clip_max) | |||||
else: | |||||
images = x_origin + (1 + self._overshoot)*r_tot | |||||
return images | |||||
def _generate_detection(self, inputs, labels): | def _generate_detection(self, inputs, labels): | ||||
"""Generate adversarial examples in detection scenario""" | """Generate adversarial examples in detection scenario""" | ||||
@@ -239,19 +271,12 @@ class DeepFool(Attack): | |||||
raise NotImplementedError(msg) | raise NotImplementedError(msg) | ||||
r_tot[idx, ...] = r_tot[idx, ...] + r_i | r_tot[idx, ...] = r_tot[idx, ...] + r_i | ||||
if self._bounds is not None: | |||||
clip_min, clip_max = self._bounds | |||||
images = x_origin + (1 + self._overshoot)*r_tot*(clip_max-clip_min) | |||||
images = np.clip(images, clip_min, clip_max) | |||||
else: | |||||
images = x_origin + (1 + self._overshoot)*r_tot | |||||
images = self._update_image(x_origin, r_tot) | |||||
iteration += 1 | iteration += 1 | ||||
images = images.astype(images_dtype) | images = images.astype(images_dtype) | ||||
del preds_logits, grads | del preds_logits, grads | ||||
return images | return images | ||||
def _generate_classification(self, inputs, labels): | def _generate_classification(self, inputs, labels): | ||||
"""Generate adversarial examples in classification scenario""" | """Generate adversarial examples in classification scenario""" | ||||
inputs, labels = check_pair_numpy_param('inputs', inputs, | inputs, labels = check_pair_numpy_param('inputs', inputs, | ||||
@@ -47,9 +47,25 @@ class GradientMethod(Attack): | |||||
is already equipped with loss function. Default: None. | is already equipped with loss function. Default: None. | ||||
Examples: | Examples: | ||||
>>> import numpy as np | |||||
>>> import mindspore.nn as nn | |||||
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | |||||
>>> from mindspore import Tensor | |||||
>>> from mindarmour.adv_robustness.attacksimport FastGradientMethod | |||||
>>> | |||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._relu = nn.ReLU() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> out = self._relu(inputs) | |||||
>>> return out | |||||
>>> | |||||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | ||||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | ||||
>>> attack = FastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> net = Net() | |||||
>>> attack = FastGradientMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> adv_x = attack.generate(inputs, labels) | >>> adv_x = attack.generate(inputs, labels) | ||||
""" | """ | ||||
@@ -155,9 +171,24 @@ class FastGradientMethod(GradientMethod): | |||||
is already equipped with loss function. Default: None. | is already equipped with loss function. Default: None. | ||||
Examples: | Examples: | ||||
>>> import numpy as np | |||||
>>> import mindspore.nn as nn | |||||
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | |||||
>>> from mindarmour.adv_robustness.attacks import FastGradientMethod | |||||
>>> | |||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._relu = nn.ReLU() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> out = self._relu(inputs) | |||||
>>> return out | |||||
>>> | |||||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | ||||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | ||||
>>> attack = FastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> net = Net() | |||||
>>> attack = FastGradientMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> adv_x = attack.generate(inputs, labels) | >>> adv_x = attack.generate(inputs, labels) | ||||
""" | """ | ||||
@@ -223,9 +254,24 @@ class RandomFastGradientMethod(FastGradientMethod): | |||||
ValueError: eps is smaller than alpha! | ValueError: eps is smaller than alpha! | ||||
Examples: | Examples: | ||||
>>> import numpy as np | |||||
>>> import mindspore.nn as nn | |||||
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | |||||
>>> from mindarmour.adv_robustness.attacks import RandomFastGradientMethod | |||||
>>> | |||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._relu = nn.ReLU() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> out = self._relu(inputs) | |||||
>>> return out | |||||
>>> | |||||
>>> net = Net() | |||||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | ||||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | ||||
>>> attack = RandomFastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> attack = RandomFastGradientMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> adv_x = attack.generate(inputs, labels) | >>> adv_x = attack.generate(inputs, labels) | ||||
""" | """ | ||||
@@ -265,9 +311,24 @@ class FastGradientSignMethod(GradientMethod): | |||||
is already equipped with loss function. Default: None. | is already equipped with loss function. Default: None. | ||||
Examples: | Examples: | ||||
>>> import numpy as np | |||||
>>> import mindspore.nn as nn | |||||
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | |||||
>>> from mindarmour.adv_robustness.attacks import FastGradientSignMethod | |||||
>>> | |||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._relu = nn.ReLU() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> out = self._relu(inputs) | |||||
>>> return out | |||||
>>> | |||||
>>> net = Net() | |||||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | ||||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | ||||
>>> attack = FastGradientSignMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> attack = FastGradientSignMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> adv_x = attack.generate(inputs, labels) | >>> adv_x = attack.generate(inputs, labels) | ||||
""" | """ | ||||
@@ -329,9 +390,24 @@ class RandomFastGradientSignMethod(FastGradientSignMethod): | |||||
ValueError: eps is smaller than alpha! | ValueError: eps is smaller than alpha! | ||||
Examples: | Examples: | ||||
>>> import numpy as np | |||||
>>> import mindspore.nn as nn | |||||
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | |||||
>>> from mindarmour.adv_robustness.attacks import RandomFastGradientSignMethod | |||||
>>> | |||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._relu = nn.ReLU() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> out = self._relu(inputs) | |||||
>>> return out | |||||
>>> | |||||
>>> net = Net() | |||||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | ||||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | ||||
>>> attack = RandomFastGradientSignMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> attack = RandomFastGradientSignMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> adv_x = attack.generate(inputs, labels) | >>> adv_x = attack.generate(inputs, labels) | ||||
""" | """ | ||||
@@ -366,6 +442,21 @@ class LeastLikelyClassMethod(FastGradientSignMethod): | |||||
is already equipped with loss function. Default: None. | is already equipped with loss function. Default: None. | ||||
Examples: | Examples: | ||||
>>> import numpy as np | |||||
>>> import mindspore.nn as nn | |||||
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | |||||
>>> from mindarmour.adv_robustness.attacks import LeastLikelyClassMethod | |||||
>>> | |||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._relu = nn.ReLU() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> out = self._relu(inputs) | |||||
>>> return out | |||||
>>> | |||||
>>> net = Net() | |||||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | ||||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | ||||
>>> attack = LeastLikelyClassMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | >>> attack = LeastLikelyClassMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | ||||
@@ -404,6 +495,21 @@ class RandomLeastLikelyClassMethod(FastGradientSignMethod): | |||||
ValueError: eps is smaller than alpha! | ValueError: eps is smaller than alpha! | ||||
Examples: | Examples: | ||||
>>> import numpy as np | |||||
>>> import mindspore.nn as nn | |||||
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | |||||
>>> from mindarmour.adv_robustness.attacks import RandomLeastLikelyClassMethod | |||||
>>> | |||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._relu = nn.ReLU() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> out = self._relu(inputs) | |||||
>>> return out | |||||
>>> | |||||
>>> net = Net() | |||||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | ||||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | ||||
>>> attack = RandomLeastLikelyClassMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | >>> attack = RandomLeastLikelyClassMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | ||||
@@ -184,7 +184,22 @@ class BasicIterativeMethod(IterativeGradientMethod): | |||||
is already equipped with loss function. Default: None. | is already equipped with loss function. Default: None. | ||||
Examples: | Examples: | ||||
>>> attack = BasicIterativeMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> import numpy as np | |||||
>>> import mindspore.nn as nn | |||||
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | |||||
>>> from mindarmour.adv_robustness.attacks import BasicIterativeMethod | |||||
>>> | |||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._relu = nn.ReLU() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> out = self._relu(inputs) | |||||
>>> return out | |||||
>>> | |||||
>>> net = Net() | |||||
>>> attack = BasicIterativeMethod(netw, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
""" | """ | ||||
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | ||||
is_targeted=False, nb_iter=5, loss_fn=None): | is_targeted=False, nb_iter=5, loss_fn=None): | ||||
@@ -215,6 +230,17 @@ class BasicIterativeMethod(IterativeGradientMethod): | |||||
numpy.ndarray, generated adversarial examples. | numpy.ndarray, generated adversarial examples. | ||||
Examples: | Examples: | ||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._relu = nn.ReLU() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> out = self._relu(inputs) | |||||
>>> return out | |||||
>>> | |||||
>>> net = Net() | |||||
>>> attack = BasicIterativeMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> adv_x = attack.generate([[0.3, 0.2, 0.6], | >>> adv_x = attack.generate([[0.3, 0.2, 0.6], | ||||
>>> [0.3, 0.2, 0.4]], | >>> [0.3, 0.2, 0.4]], | ||||
>>> [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0], | >>> [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0], | ||||
@@ -303,6 +329,22 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||||
numpy.ndarray, generated adversarial examples. | numpy.ndarray, generated adversarial examples. | ||||
Examples: | Examples: | ||||
>>> import numpy as np | |||||
>>> import mindspore.nn as nn | |||||
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | |||||
>>> from mindarmour.adv_robustness.attacks import MomentumIterativeMethod | |||||
>>> | |||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._relu = nn.ReLU() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> out = self._relu(inputs) | |||||
>>> return out | |||||
>>> | |||||
>>> net = Net() | |||||
>>> attack = MomentumIterativeMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> adv_x = attack.generate([[0.5, 0.2, 0.6], | >>> adv_x = attack.generate([[0.5, 0.2, 0.6], | ||||
>>> [0.3, 0, 0.2]], | >>> [0.3, 0, 0.2]], | ||||
>>> [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0], | >>> [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0], | ||||
@@ -433,6 +475,22 @@ class ProjectedGradientDescent(BasicIterativeMethod): | |||||
numpy.ndarray, generated adversarial examples. | numpy.ndarray, generated adversarial examples. | ||||
Examples: | Examples: | ||||
>>> import numpy as np | |||||
>>> import mindspore.nn as nn | |||||
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | |||||
>>> from mindarmour.adv_robustness.attacks import ProjectedGradientDescent | |||||
>>> | |||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._relu = nn.ReLU() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> out = self._relu(inputs) | |||||
>>> return out | |||||
>>> | |||||
>>> net = Net() | |||||
>>> attack = ProjectedGradientDescent(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||||
>>> adv_x = attack.generate([[0.6, 0.2, 0.6], | >>> adv_x = attack.generate([[0.6, 0.2, 0.6], | ||||
>>> [0.3, 0.3, 0.4]], | >>> [0.3, 0.3, 0.4]], | ||||
>>> [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1], | >>> [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1], | ||||
@@ -54,7 +54,23 @@ class JSMAAttack(Attack): | |||||
input labels are onehot-coded. Default: True. | input labels are onehot-coded. Default: True. | ||||
Examples: | Examples: | ||||
>>> attack = JSMAAttack(network) | |||||
>>> import numpy as np | |||||
>>> import mindspore.nn as nn | |||||
>>> from mindspore.nn import Cell | |||||
>>> from mindarmour.adv_robustness.attacks import JSMAAttack | |||||
>>> class Net(Cell): | |||||
>>> def __init__(self): | |||||
>>> super(Net, self).__init__() | |||||
>>> self._relu = nn.ReLU() | |||||
>>> | |||||
>>> def construct(self, inputs): | |||||
>>> out = self._relu(inputs) | |||||
>>> return out | |||||
>>> | |||||
>>> net = Net() | |||||
>>> input_shape = (1, 5) | |||||
>>> batch_size, classes = input_shape | |||||
>>> attack = JSMAAttack(net, classes, max_iteration=5) | |||||
""" | """ | ||||
def __init__(self, network, num_classes, box_min=0.0, box_max=1.0, | def __init__(self, network, num_classes, box_min=0.0, box_max=1.0, | ||||
@@ -181,7 +197,13 @@ class JSMAAttack(Attack): | |||||
numpy.ndarray, adversarial samples. | numpy.ndarray, adversarial samples. | ||||
Examples: | Examples: | ||||
>>> advs = generate([[0.2, 0.3, 0.4], [0.3, 0.4, 0.5]], [1, 2]) | |||||
>>> input_shape = (1, 5) | |||||
>>> input_np = np.random.random(input_shape).astype(np.float32) | |||||
>>> label_np = np.random.randint(classes, size=batch_size) | |||||
>>> batch_size, classes = input_shape | |||||
>>> | |||||
>>> attack = JSMAAttack(net, classes, max_iteration=5) | |||||
>>> advs = attack.generate(input_np, label_np) | |||||
""" | """ | ||||
inputs, labels = check_pair_numpy_param('inputs', inputs, | inputs, labels = check_pair_numpy_param('inputs', inputs, | ||||
'labels', labels) | 'labels', labels) | ||||
@@ -54,7 +54,12 @@ class LBFGS(Attack): | |||||
input labels are onehot-coded. Default: False. | input labels are onehot-coded. Default: False. | ||||
Examples: | Examples: | ||||
>>> attack = LBFGS(network) | |||||
>>> import numpy as np | |||||
>>> from mindarmour.adv_robustness.attacks import LBFGS | |||||
>>> from tests.ut.python.utils.mock_net import Net | |||||
>>> | |||||
>>> net = Net() | |||||
>>> attack = LBFGS(net, is_targeted=True) | |||||
""" | """ | ||||
def __init__(self, network, eps=1e-5, bounds=(0.0, 1.0), is_targeted=True, | def __init__(self, network, eps=1e-5, bounds=(0.0, 1.0), is_targeted=True, | ||||
nb_iter=150, search_iters=30, loss_fn=None, sparse=False): | nb_iter=150, search_iters=30, loss_fn=None, sparse=False): | ||||
@@ -94,6 +99,7 @@ class LBFGS(Attack): | |||||
numpy.ndarray, generated adversarial examples. | numpy.ndarray, generated adversarial examples. | ||||
Examples: | Examples: | ||||
>>> attack = LBFGS(net, is_targeted=True) | |||||
>>> adv = attack.generate([[0.1, 0.2, 0.6], [0.3, 0, 0.4]], [2, 2]) | >>> adv = attack.generate([[0.1, 0.2, 0.6], [0.3, 0, 0.4]], [2, 2]) | ||||
""" | """ | ||||
LOGGER.debug(TAG, 'start to generate adv image.') | LOGGER.debug(TAG, 'start to generate adv image.') | ||||
@@ -191,7 +197,7 @@ class LBFGS(Attack): | |||||
def _optimize(self, start_input, labels, epsilon): | def _optimize(self, start_input, labels, epsilon): | ||||
""" | """ | ||||
Given loss fuction and gradient, use l_bfgs_b algorithm to update input | |||||
Given loss function and gradient, use l_bfgs_b algorithm to update input | |||||
sample. The epsilon will be doubled until an adversarial example is found. | sample. The epsilon will be doubled until an adversarial example is found. | ||||
Args: | Args: | ||||
@@ -28,7 +28,7 @@ from tests.ut.python.utils.mock_net import Net | |||||
context.set_context(mode=context.GRAPH_MODE) | context.set_context(mode=context.GRAPH_MODE) | ||||
LOGGER = LogUtil.get_instance() | LOGGER = LogUtil.get_instance() | ||||
TAG = 'HopSkipJumpAttack' | |||||
TAG = 'NaturalEvolutionaryStrategy' | |||||
class ModelToBeAttacked(BlackModel): | class ModelToBeAttacked(BlackModel): | ||||
@@ -100,7 +100,7 @@ def get_dataset(current_dir): | |||||
def nes_mnist_attack(scene, top_k): | def nes_mnist_attack(scene, top_k): | ||||
""" | """ | ||||
hsja-Attack test | |||||
nes-Attack test | |||||
""" | """ | ||||
current_dir = os.path.dirname(os.path.abspath(__file__)) | current_dir = os.path.dirname(os.path.abspath(__file__)) | ||||
test_images, test_labels = get_dataset(current_dir) | test_images, test_labels = get_dataset(current_dir) | ||||