From 47d89be4591118888f463219fa501af1846f70dd Mon Sep 17 00:00:00 2001 From: wjtest1215 Date: Wed, 26 Oct 2022 17:29:12 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20'gpu=5Fnew/test=5Finferenc?= =?UTF-8?q?e.py'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gpu_new/test_inference.py | 80 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 gpu_new/test_inference.py diff --git a/gpu_new/test_inference.py b/gpu_new/test_inference.py new file mode 100644 index 0000000..68952e3 --- /dev/null +++ b/gpu_new/test_inference.py @@ -0,0 +1,80 @@ +#!/usr/bin/python +#coding=utf-8 +''' +GPU INFERENCE INSTANCE + +If there are Chinese comments in the code,please add at the beginning: +#!/usr/bin/python +#coding=utf-8 +Due to the adaptability of a100, please use the recommended image of the +platform with cuda 11.Then adjust the code and submit the image. +The image of this example is: dockerhub.pcl.ac.cn:5000/user-images/openi:cuda111_python37_pytorch191 +In the environment, the uploaded dataset will be automatically placed in the /dataset directory. +if MnistDataset_torch.zip is selected,Then the dataset directory is /dataset/test; + +The model file selected is in /model directory. +The result download path is under /result . and the Qizhi platform will provide file downloads under the /result directory. + +本例中的镜像是dockerhub.pcl.ac.cn:5000/user-images/openi:cuda111_python37_pytorch191 +选择的数据集被放置在/dataset目录 +选择的模型文件放置在/model目录 +输出结果路径是/result目录 + +!!!注意:目前推理的资源环境不支持联网,所以镜像无法使用公网镜像,镜像必须先提交到启智平台;推理的数据集也需要先上传到启智平台 + +''' + + +import numpy as np +import torch +from torchvision.datasets import mnist +from torch.utils.data import DataLoader +from torchvision.transforms import ToTensor +import os +import argparse +from model import Model + + + +# Training settings +parser = argparse.ArgumentParser(description='PyTorch MNIST Example') +#获取模型文件名称 +parser.add_argument('--modelname', help='model name') + + + +if __name__ == '__main__': + args, unknown = parser.parse_known_args() + print('cuda is available:{}'.format(torch.cuda.is_available())) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + test_dataset = mnist.MNIST(root='/dataset/test', train=False, transform=ToTensor(), + download=False) + test_loader = DataLoader(test_dataset, batch_size=256) + #如果文件名确定,model_path可以直接写死 + model_path = '/model/'+args.modelname + + model = Model().to(device) + checkpoint = torch.load(model_path) + model.load_state_dict(checkpoint['model']) + + model.eval() + + correct = 0 + _sum = 0 + + for idx, (test_x, test_label) in enumerate(test_loader): + test_x = test_x + test_label = test_label + predict_y = model(test_x.to(device).float()).detach() + predict_ys = np.argmax(predict_y.cpu(), axis=-1) + label_np = test_label.numpy() + _ = predict_ys == test_label + correct += np.sum(_.numpy(), axis=-1) + _sum += _.shape[0] + print('accuracy: {:.2f}'.format(correct / _sum)) + #结果写入/result + filename = 'result.txt' + file_path = os.path.join('/result', filename) + with open(file_path, 'w') as file: + file.write('accuracy: {:.2f}'.format(correct / _sum)) \ No newline at end of file