From: @jxlang910 Reviewed-by: @pkuliuliu,@liu_luobin Signed-off-by: @pkuliuliutags/v1.1.0
@@ -182,11 +182,11 @@ class Attack: | |||||
best_position = check_numpy_param('best_position', best_position) | best_position = check_numpy_param('best_position', best_position) | ||||
x_ori, best_position = check_equal_shape('x_ori', x_ori, 'best_position', best_position) | x_ori, best_position = check_equal_shape('x_ori', x_ori, 'best_position', best_position) | ||||
x_shape = best_position.shape | x_shape = best_position.shape | ||||
reduction_iters = 10000 # recover 0.01% each step | |||||
reduction_iters = 1000 # recover 0.1% each step | |||||
_, original_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) | _, original_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) | ||||
for _ in range(reduction_iters): | for _ in range(reduction_iters): | ||||
diff = x_ori - best_position | diff = x_ori - best_position | ||||
res = 0.5*diff*(np.random.random(x_shape) < 0.0001) | |||||
res = 0.5*diff*(np.random.random(x_shape) < 0.001) | |||||
best_position += res | best_position += res | ||||
_, correct_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) | _, correct_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) | ||||
q_times += 1 | q_times += 1 | ||||
@@ -46,16 +46,16 @@ class GeneticAttack(Attack): | |||||
targeted (bool): If True, turns on the targeted attack. If False, | targeted (bool): If True, turns on the targeted attack. If False, | ||||
turns on untargeted attack. It should be noted that only untargeted attack | turns on untargeted attack. It should be noted that only untargeted attack | ||||
is supproted for model_type='detection', Default: False. | is supproted for model_type='detection', Default: False. | ||||
reserve_ratio (float): The percentage of objects that can be detected after attacks, specifically for | |||||
model_type='detection'. Default: 0.3. | |||||
reserve_ratio (Union[int, float]): The percentage of objects that can be detected after attacks, | |||||
specifically for model_type='detection'. Default: 0.3. | |||||
pop_size (int): The number of particles, which should be greater than | pop_size (int): The number of particles, which should be greater than | ||||
zero. Default: 6. | zero. Default: 6. | ||||
mutation_rate (float): The probability of mutations. Default: 0.005. | |||||
per_bounds (float): Maximum L_inf distance. | |||||
mutation_rate (Union[int, float]): The probability of mutations. Default: 0.005. | |||||
per_bounds (Union[int, float]): Maximum L_inf distance. | |||||
max_steps (int): The maximum round of iteration for each adversarial | max_steps (int): The maximum round of iteration for each adversarial | ||||
example. Default: 1000. | example. Default: 1000. | ||||
step_size (float): Attack step size. Default: 0.2. | |||||
temp (float): Sampling temperature for selection. Default: 0.3. | |||||
step_size (Union[int, float]): Attack step size. Default: 0.2. | |||||
temp (Union[int, float]): Sampling temperature for selection. Default: 0.3. | |||||
The greater the temp, the greater the differences between individuals' | The greater the temp, the greater the differences between individuals' | ||||
selecting probabilities. | selecting probabilities. | ||||
bounds (Union[tuple, list, None]): Upper and lower bounds of data. In form | bounds (Union[tuple, list, None]): Upper and lower bounds of data. In form | ||||
@@ -65,7 +65,7 @@ class GeneticAttack(Attack): | |||||
Default: False. | Default: False. | ||||
sparse (bool): If True, input labels are sparse-encoded. If False, | sparse (bool): If True, input labels are sparse-encoded. If False, | ||||
input labels are one-hot-encoded. Default: True. | input labels are one-hot-encoded. Default: True. | ||||
c (float): Weight of perturbation loss. Default: 0.1. | |||||
c (Union[int, float]): Weight of perturbation loss. Default: 0.1. | |||||
Examples: | Examples: | ||||
>>> attack = GeneticAttack(model) | >>> attack = GeneticAttack(model) | ||||
@@ -76,6 +76,10 @@ class GeneticAttack(Attack): | |||||
super(GeneticAttack, self).__init__() | super(GeneticAttack, self).__init__() | ||||
self._model = check_model('model', model, BlackModel) | self._model = check_model('model', model, BlackModel) | ||||
self._model_type = check_param_type('model_type', model_type, str) | self._model_type = check_param_type('model_type', model_type, str) | ||||
if self._model_type not in ('classification', 'detection'): | |||||
msg = "Only 'classification' or 'detection' is supported now, but got {}.".format(self._model_type) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
self._targeted = check_param_type('targeted', targeted, bool) | self._targeted = check_param_type('targeted', targeted, bool) | ||||
self._reserve_ratio = check_value_non_negative('reserve_ratio', reserve_ratio) | self._reserve_ratio = check_value_non_negative('reserve_ratio', reserve_ratio) | ||||
if self._reserve_ratio > 1: | if self._reserve_ratio > 1: | ||||
@@ -153,10 +157,14 @@ class GeneticAttack(Attack): | |||||
if self._model_type == 'classification': | if self._model_type == 'classification': | ||||
inputs, labels = check_pair_numpy_param('inputs', inputs, | inputs, labels = check_pair_numpy_param('inputs', inputs, | ||||
'labels', labels) | 'labels', labels) | ||||
if not self._sparse: | |||||
if labels.ndim != 2: | |||||
raise ValueError('labels must be 2 dims, ' | |||||
'but got {} dims.'.format(labels.ndim)) | |||||
if self._sparse: | |||||
label_squ = np.squeeze(labels) | |||||
if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]: | |||||
msg = "The parameter 'sparse' of GeneticAttack is True, but the input labels is not sparse style " \ | |||||
"and got its shape as {}.".format(labels.shape) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError | |||||
else: | |||||
labels = np.argmax(labels, axis=1) | labels = np.argmax(labels, axis=1) | ||||
images = inputs | images = inputs | ||||
elif self._model_type == 'detection': | elif self._model_type == 'detection': | ||||
@@ -41,17 +41,17 @@ class PSOAttack(Attack): | |||||
Args: | Args: | ||||
model (BlackModel): Target model. | model (BlackModel): Target model. | ||||
step_size (float): Attack step size. Default: 0.5. | |||||
per_bounds (float): Relative variation range of perturbations. Default: 0.6. | |||||
c1 (float): Weight coefficient. Default: 2. | |||||
c2 (float): Weight coefficient. Default: 2. | |||||
c (float): Weight of perturbation loss. Default: 2. | |||||
step_size (Union[int, float]): Attack step size. Default: 0.5. | |||||
per_bounds (Union[int, float]): Relative variation range of perturbations. Default: 0.6. | |||||
c1 (Union[int, float]): Weight coefficient. Default: 2. | |||||
c2 (Union[int, float]): Weight coefficient. Default: 2. | |||||
c (Union[int, float]): Weight of perturbation loss. Default: 2. | |||||
pop_size (int): The number of particles, which should be greater | pop_size (int): The number of particles, which should be greater | ||||
than zero. Default: 6. | than zero. Default: 6. | ||||
t_max (int): The maximum round of iteration for each adversarial example, | t_max (int): The maximum round of iteration for each adversarial example, | ||||
which should be greater than zero. Default: 1000. | which should be greater than zero. Default: 1000. | ||||
pm (float): The probability of mutations. Default: 0.5. | |||||
bounds (tuple): Upper and lower bounds of data. In form of (clip_min, | |||||
pm (Union[int, float]): The probability of mutations. Default: 0.5. | |||||
bounds (Union[list, tuple, None]): Upper and lower bounds of data. In form of (clip_min, | |||||
clip_max). Default: None. | clip_max). Default: None. | ||||
targeted (bool): If True, turns on the targeted attack. If False, | targeted (bool): If True, turns on the targeted attack. If False, | ||||
turns on untargeted attack. It should be noted that only untargeted attack | turns on untargeted attack. It should be noted that only untargeted attack | ||||
@@ -60,8 +60,8 @@ class PSOAttack(Attack): | |||||
input labels are one-hot-encoded. Default: True. | input labels are one-hot-encoded. Default: True. | ||||
model_type (str): The type of targeted model. 'classification' and 'detection' are supported now. | model_type (str): The type of targeted model. 'classification' and 'detection' are supported now. | ||||
default: 'classification'. | default: 'classification'. | ||||
reserve_ratio (float): The percentage of objects that can be detected after attacks, specifically for | |||||
model_type='detection'. Default: 0.3. | |||||
reserve_ratio (Union[int, float]): The percentage of objects that can be detected after attacks, | |||||
specifically for model_type='detection'. Default: 0.3. | |||||
Examples: | Examples: | ||||
>>> attack = PSOAttack(model) | >>> attack = PSOAttack(model) | ||||
@@ -161,7 +161,7 @@ class PSOAttack(Attack): | |||||
pixel_deep = self._bounds[1] - self._bounds[0] | pixel_deep = self._bounds[1] - self._bounds[0] | ||||
cur_pop = check_numpy_param('cur_pop', cur_pop) | cur_pop = check_numpy_param('cur_pop', cur_pop) | ||||
perturb_noise = (np.random.random(cur_pop.shape) - 0.5)*pixel_deep | perturb_noise = (np.random.random(cur_pop.shape) - 0.5)*pixel_deep | ||||
mutated_pop = perturb_noise + cur_pop | |||||
mutated_pop = perturb_noise*(np.random.random(cur_pop.shape) < self._pm) + cur_pop | |||||
if self._model_type == 'classification': | if self._model_type == 'classification': | ||||
mutated_pop = np.clip(np.clip(mutated_pop, cur_pop - self._per_bounds*np.abs(cur_pop), | mutated_pop = np.clip(np.clip(mutated_pop, cur_pop - self._per_bounds*np.abs(cur_pop), | ||||
cur_pop + self._per_bounds*np.abs(cur_pop)), | cur_pop + self._per_bounds*np.abs(cur_pop)), | ||||
@@ -194,7 +194,14 @@ class PSOAttack(Attack): | |||||
if self._model_type == 'classification': | if self._model_type == 'classification': | ||||
inputs, labels = check_pair_numpy_param('inputs', inputs, | inputs, labels = check_pair_numpy_param('inputs', inputs, | ||||
'labels', labels) | 'labels', labels) | ||||
if not self._sparse: | |||||
if self._sparse: | |||||
label_squ = np.squeeze(labels) | |||||
if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]: | |||||
msg = "The parameter 'sparse' of PSOAttack is True, but the input labels is not sparse style and " \ | |||||
"got its shape as {}.".format(labels.shape) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError | |||||
else: | |||||
labels = np.argmax(labels, axis=1) | labels = np.argmax(labels, axis=1) | ||||
images = inputs | images = inputs | ||||
elif self._model_type == 'detection': | elif self._model_type == 'detection': | ||||
@@ -302,6 +302,8 @@ def check_detection_inputs(inputs, labels): | |||||
raise ValueError(msg) | raise ValueError(msg) | ||||
else: | else: | ||||
check_numpy_param('inputs', inputs) | check_numpy_param('inputs', inputs) | ||||
images = inputs | |||||
auxiliary_inputs = () | |||||
check_param_type('labels', labels, tuple) | check_param_type('labels', labels, tuple) | ||||
if len(labels) != 2: | if len(labels) != 2: | ||||
@@ -24,8 +24,6 @@ from mindspore.nn import Cell | |||||
from mindarmour import BlackModel | from mindarmour import BlackModel | ||||
from mindarmour.adv_robustness.attacks import GeneticAttack | from mindarmour.adv_robustness.attacks import GeneticAttack | ||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
# for user | # for user | ||||
class ModelToBeAttacked(BlackModel): | class ModelToBeAttacked(BlackModel): | ||||
@@ -41,6 +39,28 @@ class ModelToBeAttacked(BlackModel): | |||||
return result.asnumpy() | return result.asnumpy() | ||||
class DetectionModel(BlackModel): | |||||
"""model to be attack""" | |||||
def predict(self, inputs): | |||||
"""predict""" | |||||
# Adapt to the input shape requirements of the target network if inputs is only one image. | |||||
if len(inputs.shape) == 3: | |||||
inputs_num = 1 | |||||
else: | |||||
inputs_num = inputs.shape[0] | |||||
box_and_confi = [] | |||||
pred_labels = [] | |||||
gt_number = np.random.randint(1, 128) | |||||
for _ in range(inputs_num): | |||||
boxes_i = np.random.random((gt_number, 5)) | |||||
labels_i = np.random.randint(0, 10, gt_number) | |||||
box_and_confi.append(boxes_i) | |||||
pred_labels.append(labels_i) | |||||
return np.array(box_and_confi), np.array(pred_labels) | |||||
class SimpleNet(Cell): | class SimpleNet(Cell): | ||||
""" | """ | ||||
Construct the network of target model. | Construct the network of target model. | ||||
@@ -76,6 +96,7 @@ def test_genetic_attack(): | |||||
""" | """ | ||||
Genetic_Attack test | Genetic_Attack test | ||||
""" | """ | ||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
batch_size = 6 | batch_size = 6 | ||||
net = SimpleNet() | net = SimpleNet() | ||||
@@ -98,6 +119,7 @@ def test_genetic_attack(): | |||||
@pytest.mark.env_card | @pytest.mark.env_card | ||||
@pytest.mark.component_mindarmour | @pytest.mark.component_mindarmour | ||||
def test_supplement(): | def test_supplement(): | ||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
batch_size = 6 | batch_size = 6 | ||||
net = SimpleNet() | net = SimpleNet() | ||||
@@ -123,6 +145,7 @@ def test_supplement(): | |||||
@pytest.mark.component_mindarmour | @pytest.mark.component_mindarmour | ||||
def test_value_error(): | def test_value_error(): | ||||
"""test that exception is raised for invalid labels""" | """test that exception is raised for invalid labels""" | ||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
batch_size = 6 | batch_size = 6 | ||||
net = SimpleNet() | net = SimpleNet() | ||||
@@ -140,3 +163,29 @@ def test_value_error(): | |||||
# raise error | # raise error | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
assert attack.generate(inputs, labels) | assert attack.generate(inputs, labels) | ||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_genetic_attack_detection_cpu(): | |||||
""" | |||||
Genetic_Attack test | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
batch_size = 2 | |||||
inputs = np.random.random((batch_size, 3, 28, 28)) | |||||
model = DetectionModel() | |||||
attack = GeneticAttack(model, model_type='detection', pop_size=6, mutation_rate=0.05, | |||||
per_bounds=0.1, step_size=0.25, temp=0.1, | |||||
sparse=False, max_steps=50) | |||||
# generate adversarial samples | |||||
adv_imgs = [] | |||||
for i in range(batch_size): | |||||
img_data = np.expand_dims(inputs[i], axis=0) | |||||
pre_gt_boxes, pre_gt_labels = model.predict(inputs) | |||||
_, adv_img, _ = attack.generate(img_data, (pre_gt_boxes, pre_gt_labels)) | |||||
adv_imgs.append(adv_img) | |||||
assert np.any(inputs != np.array(adv_imgs)) |
@@ -43,6 +43,28 @@ class ModelToBeAttacked(BlackModel): | |||||
return result.asnumpy() | return result.asnumpy() | ||||
class DetectionModel(BlackModel): | |||||
"""model to be attack""" | |||||
def predict(self, inputs): | |||||
"""predict""" | |||||
# Adapt to the input shape requirements of the target network if inputs is only one image. | |||||
if len(inputs.shape) == 3: | |||||
inputs_num = 1 | |||||
else: | |||||
inputs_num = inputs.shape[0] | |||||
box_and_confi = [] | |||||
pred_labels = [] | |||||
gt_number = np.random.randint(1, 128) | |||||
for _ in range(inputs_num): | |||||
boxes_i = np.random.random((gt_number, 5)) | |||||
labels_i = np.random.randint(0, 10, gt_number) | |||||
box_and_confi.append(boxes_i) | |||||
pred_labels.append(labels_i) | |||||
return np.array(box_and_confi), np.array(pred_labels) | |||||
class SimpleNet(Cell): | class SimpleNet(Cell): | ||||
""" | """ | ||||
Construct the network of target model. | Construct the network of target model. | ||||
@@ -167,3 +189,27 @@ def test_pso_attack_cpu(): | |||||
attack = PSOAttack(model, bounds=(0.0, 1.0), pm=0.5, sparse=False) | attack = PSOAttack(model, bounds=(0.0, 1.0), pm=0.5, sparse=False) | ||||
_, adv_data, _ = attack.generate(inputs, labels) | _, adv_data, _ = attack.generate(inputs, labels) | ||||
assert np.any(inputs != adv_data) | assert np.any(inputs != adv_data) | ||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_pso_attack_detection_cpu(): | |||||
""" | |||||
PSO_Attack test | |||||
""" | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
batch_size = 2 | |||||
inputs = np.random.random((batch_size, 3, 28, 28)) | |||||
model = DetectionModel() | |||||
attack = PSOAttack(model, t_max=30, pm=0.5, model_type='detection', reserve_ratio=0.5) | |||||
# generate adversarial samples | |||||
adv_imgs = [] | |||||
for i in range(batch_size): | |||||
img_data = np.expand_dims(inputs[i], axis=0) | |||||
pre_gt_boxes, pre_gt_labels = model.predict(inputs) | |||||
_, adv_img, _ = attack.generate(img_data, (pre_gt_boxes, pre_gt_labels)) | |||||
adv_imgs.append(adv_img) | |||||
assert np.any(inputs != np.array(adv_imgs)) |