From: @jxlang910 Reviewed-by: @pkuliuliu,@pkuliuliu,@liu_luobin Signed-off-by: @pkuliuliu,@pkuliuliutags/v1.1.0
@@ -105,7 +105,7 @@ class ModelToBeAttacked(BlackModel): | |||||
query = np.expand_dims(inputs[i].astype(np.float32), axis=0) | query = np.expand_dims(inputs[i].astype(np.float32), axis=0) | ||||
result = self._network(Tensor(query)).asnumpy() | result = self._network(Tensor(query)).asnumpy() | ||||
det_num = len(self._detector.get_detected_queries()) | det_num = len(self._detector.get_detected_queries()) | ||||
self._detector.detect([query]) | |||||
self._detector.detect(np.array([query])) | |||||
new_det_num = len(self._detector.get_detected_queries()) | new_det_num = len(self._detector.get_detected_queries()) | ||||
# If attack query detected, return random predict result | # If attack query detected, return random predict result | ||||
if new_det_num > det_num: | if new_det_num > det_num: | ||||
@@ -116,6 +116,8 @@ class ModelToBeAttacked(BlackModel): | |||||
self._detected_res.append(False) | self._detected_res.append(False) | ||||
results = np.concatenate(results) | results = np.concatenate(results) | ||||
else: | else: | ||||
if len(inputs.shape) == 3: | |||||
inputs = np.expand_dims(inputs, axis=0) | |||||
results = self._network(Tensor(inputs.astype(np.float32))).asnumpy() | results = self._network(Tensor(inputs.astype(np.float32))).asnumpy() | ||||
return results | return results | ||||
@@ -49,7 +49,13 @@ class ModelToBeAttacked(BlackModel): | |||||
""" | """ | ||||
query_num = inputs.shape[0] | query_num = inputs.shape[0] | ||||
for i in range(query_num): | for i in range(query_num): | ||||
self._queries.append(inputs[i].astype(np.float32)) | |||||
if len(inputs[i].shape) == 2: | |||||
temp = np.expand_dims(inputs[i], axis=0) | |||||
else: | |||||
temp = inputs[i] | |||||
self._queries.append(temp.astype(np.float32)) | |||||
if len(inputs.shape) == 3: | |||||
inputs = np.expand_dims(inputs, axis=0) | |||||
result = self._network(Tensor(inputs.astype(np.float32))) | result = self._network(Tensor(inputs.astype(np.float32))) | ||||
return result.asnumpy() | return result.asnumpy() | ||||
@@ -160,7 +166,7 @@ def test_similarity_detector(): | |||||
# test attack queries | # test attack queries | ||||
detector.clear_buffer() | detector.clear_buffer() | ||||
detector.detect(suspicious_queries) | |||||
detector.detect(np.array(suspicious_queries)) | |||||
LOGGER.info(TAG, 'Number of detected attack queries is : %s', | LOGGER.info(TAG, 'Number of detected attack queries is : %s', | ||||
len(detector.get_detected_queries())) | len(detector.get_detected_queries())) | ||||
LOGGER.info(TAG, 'The detected attack query indexes are : %s', | LOGGER.info(TAG, 'The detected attack query indexes are : %s', | ||||
@@ -45,7 +45,7 @@ class GeneticAttack(Attack): | |||||
default: 'classification'. | default: 'classification'. | ||||
targeted (bool): If True, turns on the targeted attack. If False, | targeted (bool): If True, turns on the targeted attack. If False, | ||||
turns on untargeted attack. It should be noted that only untargeted attack | turns on untargeted attack. It should be noted that only untargeted attack | ||||
is supproted for model_type='detection', Default: False. | |||||
is supproted for model_type='detection', Default: True. | |||||
reserve_ratio (Union[int, float]): The percentage of objects that can be detected after attacks, | reserve_ratio (Union[int, float]): The percentage of objects that can be detected after attacks, | ||||
specifically for model_type='detection'. Reserve_ratio should be in the range of (0, 1). Default: 0.3. | specifically for model_type='detection'. Reserve_ratio should be in the range of (0, 1). Default: 0.3. | ||||
pop_size (int): The number of particles, which should be greater than | pop_size (int): The number of particles, which should be greater than | ||||
@@ -141,10 +141,12 @@ class GeneticAttack(Attack): | |||||
labels (or ground_truth labels). | labels (or ground_truth labels). | ||||
Args: | Args: | ||||
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs can be (input1, input2, ...) | |||||
or only one array if model_type='detection'. | |||||
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels | |||||
should be (gt_boxes, gt_labels) if model_type='detection'. | |||||
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray if | |||||
model_type='classification'. The format of inputs can be (input1, input2, ...) or only one array if | |||||
model_type='detection'. | |||||
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels should | |||||
be numpy.ndarray if model_type='classification'. The format of labels should be (gt_boxes, gt_labels) | |||||
if model_type='detection'. | |||||
Returns: | Returns: | ||||
- numpy.ndarray, bool values for each attack result. | - numpy.ndarray, bool values for each attack result. | ||||
@@ -175,10 +175,12 @@ class PSOAttack(Attack): | |||||
labels (or ground_truth labels). | labels (or ground_truth labels). | ||||
Args: | Args: | ||||
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs can be (input1, input2, ...) | |||||
or only one array if model_type='detection'. | |||||
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground_truth labels. The format of labels | |||||
should be (gt_boxes, gt_labels) if model_type='detection'. | |||||
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray if | |||||
model_type='classification'. The format of inputs can be (input1, input2, ...) or only one array if | |||||
model_type='detection'. | |||||
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels should | |||||
be numpy.ndarray if model_type='classification'. The format of labels should be (gt_boxes, gt_labels) | |||||
if model_type='detection'. | |||||
Returns: | Returns: | ||||
- numpy.ndarray, bool values for each attack result. | - numpy.ndarray, bool values for each attack result. | ||||