Browse Source

添加 'gpu_new/test_inference.py'

test_v20221116
wjtest1215 1 year ago
parent
commit
47d89be459
1 changed files with 80 additions and 0 deletions
  1. +80
    -0
      gpu_new/test_inference.py

+ 80
- 0
gpu_new/test_inference.py View File

@@ -0,0 +1,80 @@
#!/usr/bin/python
#coding=utf-8
'''
GPU INFERENCE INSTANCE

If there are Chinese comments in the code,please add at the beginning:
#!/usr/bin/python
#coding=utf-8
Due to the adaptability of a100, please use the recommended image of the
platform with cuda 11.Then adjust the code and submit the image.
The image of this example is: dockerhub.pcl.ac.cn:5000/user-images/openi:cuda111_python37_pytorch191
In the environment, the uploaded dataset will be automatically placed in the /dataset directory.
if MnistDataset_torch.zip is selected,Then the dataset directory is /dataset/test;

The model file selected is in /model directory.
The result download path is under /result . and the Qizhi platform will provide file downloads under the /result directory.

本例中的镜像是dockerhub.pcl.ac.cn:5000/user-images/openi:cuda111_python37_pytorch191
选择的数据集被放置在/dataset目录
选择的模型文件放置在/model目录
输出结果路径是/result目录

!!!注意:目前推理的资源环境不支持联网,所以镜像无法使用公网镜像,镜像必须先提交到启智平台;推理的数据集也需要先上传到启智平台

'''


import numpy as np
import torch
from torchvision.datasets import mnist
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import os
import argparse
from model import Model



# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
#获取模型文件名称
parser.add_argument('--modelname', help='model name')



if __name__ == '__main__':
args, unknown = parser.parse_known_args()
print('cuda is available:{}'.format(torch.cuda.is_available()))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

test_dataset = mnist.MNIST(root='/dataset/test', train=False, transform=ToTensor(),
download=False)
test_loader = DataLoader(test_dataset, batch_size=256)
#如果文件名确定,model_path可以直接写死
model_path = '/model/'+args.modelname

model = Model().to(device)
checkpoint = torch.load(model_path)
model.load_state_dict(checkpoint['model'])

model.eval()

correct = 0
_sum = 0

for idx, (test_x, test_label) in enumerate(test_loader):
test_x = test_x
test_label = test_label
predict_y = model(test_x.to(device).float()).detach()
predict_ys = np.argmax(predict_y.cpu(), axis=-1)
label_np = test_label.numpy()
_ = predict_ys == test_label
correct += np.sum(_.numpy(), axis=-1)
_sum += _.shape[0]
print('accuracy: {:.2f}'.format(correct / _sum))
#结果写入/result
filename = 'result.txt'
file_path = os.path.join('/result', filename)
with open(file_path, 'w') as file:
file.write('accuracy: {:.2f}'.format(correct / _sum))

Loading…
Cancel
Save