diff --git a/examples/community/face_adversarial_attack/README.md b/examples/community/face_adversarial_attack/README.md new file mode 100644 index 0000000..b49a3ce --- /dev/null +++ b/examples/community/face_adversarial_attack/README.md @@ -0,0 +1,119 @@ + +# 人脸识别物理对抗攻击 + +## 描述 + +本项目是基于MindSpore框架对人脸识别模型的物理对抗攻击,通过生成对抗口罩,使人脸佩戴后实现有目标攻击和非目标攻击。 + +## 模型结构 + +采用华为MindSpore官方训练的FaceRecognition模型 +https://www.mindspore.cn/resources/hub/details?MindSpore/1.7/facerecognition_ms1mv2 + +## 环境要求 + +mindspore>=1.7,硬件平台为GPU。 + +## 脚本说明 + +```markdown +├── readme.md +├── photos +│ ├── adv_input //对抗图像 +│ ├── input //输入图像 +│ └── target //目标图像 +├── outputs //训练后的图像 +├── adversarial_attack.py //训练脚本 +│── example_non_target_attack.py //无目标攻击训练 +│── example_target_attack.py //有目标攻击训练 +│── loss_design.py //训练优化设置 +└── test.py //评估攻击效果 +``` + +## 模型调用 + +方法一: + +```python +#基于mindspore_hub库调用FaceRecognition模型 +import mindspore_hub as mshub +from mindspore import context +def get_model(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=0) + model = "mindspore/1.7/facerecognition_ms1mv2" + network = mshub.load(model) + network.set_train(False) + return network +``` + +方法二: + +```text +利用MindSpore代码仓中的https://gitee.com/mindspore/models/blob/master/research/cv/FaceRecognition/eval.py的get_model函数加载模型 +``` + +## 训练过程 + +有目标攻击: + +```shell +cd face_adversarial_attack/ +python example_target_attack.py +``` + +非目标攻击: + +```shell +cd face_adversarial_attack/ +python example_non_target_attack.py +``` + +## 默认训练参数 + +optimizer=adam, learning rate=0.01, weight_decay=0.0001, epoch=2000 + +## 评估过程 + +评估方法一: + +```shell +adversarial_attack.FaceAdversarialAttack.test_non_target_attack() +adversarial_attack.FaceAdversarialAttack.test_target_attack() +``` + +评估方法二: + +```shell +cd face_adversarial_attack/ +python test.py + ``` + +## 实验结果 + +有目标攻击: + +```text +input_label: 60 +target_label: 345 +The confidence of the input image on the input label: 26.67 +The confidence of the input image on the target label: 0.95 +================================ +adversarial_label: 345 +The confidence of the adversarial sample on the correct label: 1.82 +The confidence of the adversarial sample on the target label: 10.96 +input_label: 60, target_label: 345, adversarial_label: 345 + +photos中是有目标攻击的实验结果 +``` + +非目标攻击: + +```text +input_label: 60 +The confidence of the input image on the input label: 25.16 +================================ +adversarial_label: 251 +The confidence of the adversarial sample on the correct label: 9.52 +The confidence of the adversarial sample on the adversarial label: 60.96 +input_label: 60, adversarial_label: 251 +``` diff --git a/examples/community/face_adversarial_attack/adversarial_attack.py b/examples/community/face_adversarial_attack/adversarial_attack.py new file mode 100644 index 0000000..7c1641b --- /dev/null +++ b/examples/community/face_adversarial_attack/adversarial_attack.py @@ -0,0 +1,275 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Train set""" +import os +import re +import numpy as np +import face_recognition as fr +import face_recognition_models as frm +import dlib +from PIL import Image, ImageDraw +import mindspore +import mindspore.dataset.vision.py_transforms as P +from mindspore.dataset.vision.py_transforms import ToPIL as ToPILImage +from mindspore.dataset.vision.py_transforms import ToTensor +from mindspore import Parameter, ops, nn, Tensor +from loss_design import MyTrainOneStepCell, MyWithLossCellTargetAttack, \ + MyWithLossCellNonTargetAttack, FaceLossTargetAttack, FaceLossNoTargetAttack + + +class FaceAdversarialAttack(): + """ + Class used to create adversarial facial recognition attacks. + + Args: + input_img (numpy.ndarray): The input image. + target_img (numpy.ndarray): The target image. + seed (int): optional Sets custom seed for reproducibility. Default is generated randomly. + net (mindspore.Model): face recognition model. + """ + def __init__(self, input_img, target_img, net, seed=None): + if seed is not None: + np.random.seed(seed) + self.mean = Tensor([0.485, 0.456, 0.406]) + self.std = Tensor([0.229, 0.224, 0.225]) + self.expand_dims = mindspore.ops.ExpandDims() + self.imageize = ToPILImage() + self.tensorize = ToTensor() + self.normalize = P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.resnet = net + self.input_tensor = Tensor(self.normalize(self.tensorize(input_img))) + self.target_tensor = Tensor(self.normalize(self.tensorize(target_img))) + self.input_emb = self.resnet(self.expand_dims(self.input_tensor, 0)) + self.target_emb = self.resnet(self.expand_dims(self.target_tensor, 0)) + self.adversarial_emb = None + self.mask_tensor = create_mask(input_img) + self.ref = self.mask_tensor + self.pm = Parameter(self.mask_tensor) + self.opt = nn.Adam([self.pm], learning_rate=0.01, weight_decay=0.0001) + + def train(self, attack_method): + """ + Optimized adversarial image. + + Args: + attack_method (String) : Including target attack and non_target attack. + + Returns: + Tensor, adversarial image. + Tensor, mask image. + """ + + if attack_method == "non_target_attack": + loss = FaceLossNoTargetAttack() + net_with_criterion = MyWithLossCellNonTargetAttack(self.resnet, loss, self.input_tensor) + if attack_method == "target_attack": + loss = FaceLossTargetAttack(self.target_emb) + net_with_criterion = MyWithLossCellTargetAttack(self.resnet, loss, self.input_tensor) + + train_net = MyTrainOneStepCell(net_with_criterion, self.opt) + + for i in range(2000): + + self.mask_tensor = Tensor(self.pm) + + loss = train_net(self.mask_tensor) + + print("epoch %d ,loss: %f \n " % (i, loss.asnumpy().item())) + + self.mask_tensor = ops.clip_by_value( + self.mask_tensor, Tensor(0, mindspore.float32), Tensor(1, mindspore.float32)) + + adversarial_tensor = apply( + self.input_tensor, + (self.mask_tensor - self.mean[:, None, None]) / self.std[:, None, None], + self.ref) + + adversarial_tensor = self._reverse_norm(adversarial_tensor) + processed_input_tensor = self._reverse_norm(self.input_tensor) + processed_target_tensor = self._reverse_norm(self.target_tensor) + + return { + "adversarial_tensor": adversarial_tensor, + "mask_tensor": self.mask_tensor, + "processed_input_tensor": processed_input_tensor, + "processed_target_tensor": processed_target_tensor + } + + def test_target_attack(self): + """ + The model is used to test the recognition ability of adversarial images under target attack. + """ + + adversarial_tensor = apply( + self.input_tensor, + (self.mask_tensor - self.mean[:, None, None]) / self.std[:, None, None], + self.ref) + + self.adversarial_emb = self.resnet(self.expand_dims(adversarial_tensor, 0)) + self.input_emb = self.resnet(self.expand_dims(self.input_tensor, 0)) + self.target_emb = self.resnet(self.expand_dims(self.target_tensor, 0)) + + adversarial_index = np.argmax(self.adversarial_emb.asnumpy()) + target_index = np.argmax(self.target_emb.asnumpy()) + input_index = np.argmax(self.input_emb.asnumpy()) + + print("input_label:", input_index) + print("target_label:", target_index) + print("The confidence of the input image on the input label:", self.input_emb.asnumpy()[0][input_index]) + print("The confidence of the input image on the target label:", self.input_emb.asnumpy()[0][target_index]) + print("================================") + print("adversarial_label:", adversarial_index) + print("The confidence of the adversarial sample on the correct label:", + self.adversarial_emb.asnumpy()[0][input_index]) + print("The confidence of the adversarial sample on the target label:", + self.adversarial_emb.asnumpy()[0][target_index]) + print("input_label: %d, target_label: %d, adversarial_label: %d" + % (input_index, target_index, adversarial_index)) + + def test_non_target_attack(self): + """ + The model is used to test the recognition ability of adversarial images under non_target attack. + """ + + adversarial_tensor = apply( + self.input_tensor, + (self.mask_tensor - self.mean[:, None, None]) / self.std[:, None, None], + self.ref) + + self.adversarial_emb = self.resnet(self.expand_dims(adversarial_tensor, 0)) + self.input_emb = self.resnet(self.expand_dims(self.input_tensor, 0)) + + adversarial_index = np.argmax(self.adversarial_emb.asnumpy()) + input_index = np.argmax(self.input_emb.asnumpy()) + + print("input_label:", input_index) + print("The confidence of the input image on the input label:", self.input_emb.asnumpy()[0][input_index]) + print("================================") + print("adversarial_label:", adversarial_index) + print("The confidence of the adversarial sample on the correct label:", + self.adversarial_emb.asnumpy()[0][input_index]) + print("The confidence of the adversarial sample on the adversarial label:", + self.adversarial_emb.asnumpy()[0][adversarial_index]) + print( + "input_label: %d, adversarial_label: %d" % (input_index, adversarial_index)) + + def _reverse_norm(self, image_tensor): + """ + Reverses normalization for a given image_tensor. + + Args: + image_tensor (Tensor): Tensor. + + Returns: + Tensor, image. + """ + tensor = image_tensor * self.std[:, None, None] + self.mean[:, None, None] + return tensor + + +def apply(image_tensor, mask_tensor, reference_tensor): + """ + Apply a mask over an image. + + Args: + image_tensor (Tensor): Canvas to be used to apply mask on. + mask_tensor (Tensor): Mask to apply over the image. + reference_tensor (Tensor): Used to reference mask boundaries + + Returns: + Tensor, image. + """ + tensor = mindspore.numpy.where((reference_tensor == 0), image_tensor, mask_tensor) + + return tensor + + +def create_mask(face_image): + """ + Create mask image. + + Args: + face_image (PIL.Image): image of a detected face. + + Returns: + mask_tensor : a mask image. + """ + + mask = Image.new('RGB', face_image.size, color=(0, 0, 0)) + d = ImageDraw.Draw(mask) + landmarks = fr.face_landmarks(np.array(face_image)) + area = [landmark + for landmark in landmarks[0]['chin'] + if landmark[1] > max(landmarks[0]['nose_tip'])[1]] + area.append(landmarks[0]['nose_bridge'][1]) + d.polygon(area, fill=(255, 255, 255)) + mask = np.array(mask) + mask = mask.astype(np.float32) + + for i in range(mask.shape[0]): + for j in range(mask.shape[1]): + for k in range(mask.shape[2]): + if mask[i][j][k] == 255.: + mask[i][j][k] = 0.5 + else: + mask[i][j][k] = 0 + + mask_tensor = Tensor(mask) + mask_tensor = mask_tensor.swapaxes(0, 2).swapaxes(1, 2) + mask_tensor.requires_grad = True + return mask_tensor + + +def detect_face(image): + """ + Face detection and alignment process using dlib library. + + Args: + image (numpy.ndarray): image file location. + + Returns: + face_image : Resized face image. + """ + dlib_detector = dlib.get_frontal_face_detector() + dlib_shape_predictor = dlib.shape_predictor(frm.pose_predictor_model_location()) + dlib_image = dlib.load_rgb_image(image) + detections = dlib_detector(dlib_image, 1) + + dlib_faces = dlib.full_object_detections() + for det in detections: + dlib_faces.append(dlib_shape_predictor(dlib_image, det)) + face_image = Image.fromarray(dlib.get_face_chip(dlib_image, dlib_faces[0], size=112)) + + return face_image + + +def load_data(data): + """ + An auxiliary function that loads image data. + + Args: + data (String): The path to the given data. + + Returns: + list : Resize list of face images. + """ + image_files = [f for f in os.listdir(data) if re.search(r'.*\.(jpe?g|png)', f)] + image_files_locs = [os.path.join(data, f) for f in image_files] + + image_list = [] + for img in image_files_locs: + image_list.append(detect_face(img)) + + return image_list diff --git a/examples/community/face_adversarial_attack/example_non_target_attack.py b/examples/community/face_adversarial_attack/example_non_target_attack.py new file mode 100644 index 0000000..cc86a69 --- /dev/null +++ b/examples/community/face_adversarial_attack/example_non_target_attack.py @@ -0,0 +1,45 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""non target attack""" +import numpy as np +import matplotlib.image as mp +from mindspore import context +import adversarial_attack +from FaceRecognition.eval import get_model + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + +if __name__ == '__main__': + + inputs = adversarial_attack.load_data('photos/input/') + targets = adversarial_attack.load_data('photos/target/') + + net = get_model() + adversarial = adversarial_attack.FaceAdversarialAttack(inputs[0], targets[0], net) + ATTACK_METHOD = "non_target_attack" + + tensor_dict = adversarial.train(attack_method=ATTACK_METHOD) + + mp.imsave('./outputs/adversarial_example.jpg', + np.transpose(tensor_dict.get("adversarial_tensor").asnumpy(), (1, 2, 0))) + mp.imsave('./outputs/mask.jpg', + np.transpose(tensor_dict.get("mask_tensor").asnumpy(), (1, 2, 0))) + mp.imsave('./outputs/input_image.jpg', + np.transpose(tensor_dict.get("processed_input_tensor").asnumpy(), (1, 2, 0))) + mp.imsave('./outputs/target_image.jpg', + np.transpose(tensor_dict.get("processed_target_tensor").asnumpy(), (1, 2, 0))) + + adversarial.test_non_target_attack() diff --git a/examples/community/face_adversarial_attack/example_target_attack.py b/examples/community/face_adversarial_attack/example_target_attack.py new file mode 100644 index 0000000..083e0ab --- /dev/null +++ b/examples/community/face_adversarial_attack/example_target_attack.py @@ -0,0 +1,46 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""target attack""" +import numpy as np +import matplotlib.image as mp +from mindspore import context +import adversarial_attack +from FaceRecognition.eval import get_model + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + +if __name__ == '__main__': + + inputs = adversarial_attack.load_data('photos/input/') + targets = adversarial_attack.load_data('photos/target/') + + net = get_model() + adversarial = adversarial_attack.FaceAdversarialAttack(inputs[0], targets[0], net) + + ATTACK_METHOD = "target_attack" + + tensor_dict = adversarial.train(attack_method=ATTACK_METHOD) + + mp.imsave('./outputs/adversarial_example.jpg', + np.transpose(tensor_dict.get("adversarial_tensor").asnumpy(), (1, 2, 0))) + mp.imsave('./outputs/mask.jpg', + np.transpose(tensor_dict.get("mask_tensor").asnumpy(), (1, 2, 0))) + mp.imsave('./outputs/input_image.jpg', + np.transpose(tensor_dict.get("processed_input_tensor").asnumpy(), (1, 2, 0))) + mp.imsave('./outputs/target_image.jpg', + np.transpose(tensor_dict.get("processed_target_tensor").asnumpy(), (1, 2, 0))) + + adversarial.test_target_attack() diff --git a/examples/community/face_adversarial_attack/loss_design.py b/examples/community/face_adversarial_attack/loss_design.py new file mode 100644 index 0000000..65dc5dd --- /dev/null +++ b/examples/community/face_adversarial_attack/loss_design.py @@ -0,0 +1,154 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""optimization Settings""" +import mindspore +from mindspore import ops, nn, Tensor +from mindspore.dataset.vision.py_transforms import ToTensor +import mindspore.dataset.vision.py_transforms as P + + +class MyTrainOneStepCell(nn.TrainOneStepCell): + """ + Encapsulation class of network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + sens (Number): The adjust parameter. Default: 1.0. + """ + + def __init__(self, network, optimizer, sens=1.0): + super(MyTrainOneStepCell, self).__init__(network, optimizer, sens) + self.grad = ops.composite.GradOperation(get_all=True, sens_param=False) + + def construct(self, *inputs): + """Defines the computation performed.""" + loss = self.network(*inputs) + grads = self.grad(self.network)(*inputs) + self.optimizer(grads) + return loss + + +class MyWithLossCellTargetAttack(nn.Cell): + """The loss function defined by the target attack""" + def __init__(self, net, loss_fn, input_tensor): + super(MyWithLossCellTargetAttack, self).__init__(auto_prefix=False) + self.net = net + self._loss_fn = loss_fn + self.std = Tensor([0.229, 0.224, 0.225]) + self.mean = Tensor([0.485, 0.456, 0.406]) + self.expand_dims = mindspore.ops.ExpandDims() + self.normalize = P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.tensorize = ToTensor() + self.input_tensor = input_tensor + self.input_emb = self.net(self.expand_dims(self.input_tensor, 0)) + + @property + def backbone_network(self): + return self.net + + def construct(self, mask_tensor): + ref = mask_tensor + adversarial_tensor = mindspore.numpy.where( + (ref == 0), + self.input_tensor, + (mask_tensor - self.mean[:, None, None]) / self.std[:, None, None]) + adversarial_emb = self.net(self.expand_dims(adversarial_tensor, 0)) + loss = self._loss_fn(adversarial_emb) + return loss + + +class MyWithLossCellNonTargetAttack(nn.Cell): + """The loss function defined by the non target attack""" + def __init__(self, net, loss_fn, input_tensor): + super(MyWithLossCellNonTargetAttack, self).__init__(auto_prefix=False) + self.net = net + self._loss_fn = loss_fn + self.std = Tensor([0.229, 0.224, 0.225]) + self.mean = Tensor([0.485, 0.456, 0.406]) + self.expand_dims = mindspore.ops.ExpandDims() + self.normalize = P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.tensorize = ToTensor() + self.input_tensor = input_tensor + self.input_emb = self.net(self.expand_dims(self.input_tensor, 0)) + + @property + def backbone_network(self): + return self.net + + def construct(self, mask_tensor): + ref = mask_tensor + adversarial_tensor = mindspore.numpy.where( + (ref == 0), + self.input_tensor, + (mask_tensor - self.mean[:, None, None]) / self.std[:, None, None]) + adversarial_emb = self.net(self.expand_dims(adversarial_tensor, 0)) + loss = self._loss_fn(adversarial_emb, self.input_emb) + return loss + + +class FaceLossTargetAttack(nn.Cell): + """The loss function of the target attack""" + + def __init__(self, target_emb): + super(FaceLossTargetAttack, self).__init__() + self.uniformreal = ops.UniformReal(seed=2) + self.sum = ops.ReduceSum(keep_dims=False) + self.norm = nn.Norm(keep_dims=True) + self.zeroslike = ops.ZerosLike() + self.concat_op1 = ops.Concat(1) + self.concat_op2 = ops.Concat(2) + self.pow = ops.Pow() + self.reduce_sum = ops.operations.ReduceSum() + self.target_emb = target_emb + self.abs = ops.Abs() + self.reduce_mean = ops.ReduceMean() + + def construct(self, adversarial_emb): + prod_sum = self.reduce_sum(adversarial_emb * self.target_emb, (1,)) + square1 = self.reduce_sum(ops.functional.square(adversarial_emb), (1,)) + square2 = self.reduce_sum(ops.functional.square(self.target_emb), (1,)) + denom = ops.functional.sqrt(square1) * ops.functional.sqrt(square2) + loss = -(prod_sum / denom) + return loss + + +class FaceLossNoTargetAttack(nn.Cell): + """The loss function of the non-target attack""" + + def __init__(self): + """Initialization""" + super(FaceLossNoTargetAttack, self).__init__() + self.uniformreal = ops.UniformReal(seed=2) + self.sum = ops.ReduceSum(keep_dims=False) + self.norm = nn.Norm(keep_dims=True) + self.zeroslike = ops.ZerosLike() + self.concat_op1 = ops.Concat(1) + self.concat_op2 = ops.Concat(2) + self.pow = ops.Pow() + self.reduce_sum = ops.operations.ReduceSum() + self.abs = ops.Abs() + self.reduce_mean = ops.ReduceMean() + + def construct(self, adversarial_emb, input_emb): + prod_sum = self.reduce_sum(adversarial_emb * input_emb, (1,)) + square1 = self.reduce_sum(ops.functional.square(adversarial_emb), (1,)) + square2 = self.reduce_sum(ops.functional.square(input_emb), (1,)) + denom = ops.functional.sqrt(square1) * ops.functional.sqrt(square2) + loss = prod_sum / denom + return loss diff --git a/examples/community/face_adversarial_attack/photos/adv_input/adv.png b/examples/community/face_adversarial_attack/photos/adv_input/adv.png new file mode 100644 index 0000000..c9098a8 Binary files /dev/null and b/examples/community/face_adversarial_attack/photos/adv_input/adv.png differ diff --git a/examples/community/face_adversarial_attack/photos/input/input1.jpg b/examples/community/face_adversarial_attack/photos/input/input1.jpg new file mode 100644 index 0000000..3cd6d40 Binary files /dev/null and b/examples/community/face_adversarial_attack/photos/input/input1.jpg differ diff --git a/examples/community/face_adversarial_attack/photos/target/target1.jpg b/examples/community/face_adversarial_attack/photos/target/target1.jpg new file mode 100644 index 0000000..aca7d19 Binary files /dev/null and b/examples/community/face_adversarial_attack/photos/target/target1.jpg differ diff --git a/examples/community/face_adversarial_attack/test.py b/examples/community/face_adversarial_attack/test.py new file mode 100644 index 0000000..c5da3e1 --- /dev/null +++ b/examples/community/face_adversarial_attack/test.py @@ -0,0 +1,59 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test""" +import numpy as np +from mindspore import context, Tensor +import mindspore +from mindspore.dataset.vision.py_transforms import ToTensor +import mindspore.dataset.vision.py_transforms as P +from FaceRecognition.eval import get_model +import adversarial_attack + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + +if __name__ == '__main__': + + image = adversarial_attack.load_data('photos/adv_input/') + inputs = adversarial_attack.load_data('photos/input/') + targets = adversarial_attack.load_data('photos/target/') + + tensorize = ToTensor() + normalize = P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + expand_dims = mindspore.ops.ExpandDims() + mean = Tensor([0.485, 0.456, 0.406]) + std = Tensor([0.229, 0.224, 0.225]) + + resnet = get_model() + + adv = Tensor(normalize(tensorize(image[0]))) + input_tensor = Tensor(normalize(tensorize(inputs[0]))) + target_tensor = Tensor(normalize(tensorize(targets[0]))) + + adversarial_emb = resnet(expand_dims(adv, 0)) + input_emb = resnet(expand_dims(input_tensor, 0)) + target_emb = resnet(expand_dims(target_tensor, 0)) + + adversarial_index = np.argmax(adversarial_emb.asnumpy()) + target_index = np.argmax(target_emb.asnumpy()) + input_index = np.argmax(input_emb.asnumpy()) + + print("input_label:", input_index) + print("The confidence of the input image on the input label:", input_emb.asnumpy()[0][input_index]) + print("================================") + print("adversarial_label:", adversarial_index) + print("The confidence of the adversarial sample on the correct label:", adversarial_emb.asnumpy()[0][input_index]) + print("The confidence of the adversarial sample on the adversarial label:", + adversarial_emb.asnumpy()[0][adversarial_index]) + print("input_label:%d, adversarial_label:%d" % (input_index, adversarial_index))