diff --git a/mindarmour/adv_robustness/attacks/attack.py b/mindarmour/adv_robustness/attacks/attack.py index f97b662..6c8bc43 100644 --- a/mindarmour/adv_robustness/attacks/attack.py +++ b/mindarmour/adv_robustness/attacks/attack.py @@ -183,12 +183,12 @@ class Attack: x_ori, best_position = check_equal_shape('x_ori', x_ori, 'best_position', best_position) x_shape = best_position.shape reduction_iters = 10000 # recover 0.01% each step - _, original_num = self.detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) + _, original_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) for _ in range(reduction_iters): diff = x_ori - best_position res = 0.5*diff*(np.random.random(x_shape) < 0.0001) best_position += res - _, correct_num = self.detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) + _, correct_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) q_times += 1 if correct_num > original_num: best_position -= res diff --git a/mindarmour/utils/_check_param.py b/mindarmour/utils/_check_param.py index 949fa51..92698eb 100644 --- a/mindarmour/utils/_check_param.py +++ b/mindarmour/utils/_check_param.py @@ -313,15 +313,15 @@ def check_detection_inputs(inputs, labels): has_labels = False for item in labels: check_numpy_param('item', item) - if len(item.shape) == 3 and item.shape[2] == 5: + if len(item.shape) == 3: gt_boxes = item has_boxes = True elif len(item.shape) == 2: gt_labels = item has_labels = True if (not has_boxes) or (not has_labels): - msg = 'The shape of boxes array and ground-truth labels array should be (N, M, 5) and (N, M), respectively. ' \ - 'But got {} and {}.'.format(labels[0].shape, labels[1].shape) + msg = 'The shape of boxes array should be (N, M, 5) or (N, M, 4), and the shape of ground-truth' \ + 'labels array should be (N, M). But got {} and {}.'.format(labels[0].shape, labels[1].shape) LOGGER.error(TAG, msg) raise ValueError(msg) return images, auxiliary_inputs, gt_boxes, gt_labels