#!/usr/bin/python
# coding=utf-8
'''
Evaluate a pretrained PyTorch MNIST model and write the result to /tmp/output.

If there are Chinese comments in the code, please add at the beginning:
#!/usr/bin/python
#coding=utf-8

The example uses the dataset MnistDataset_torch.zip.
Dataset structure:
MnistDataset_torch.zip
├── test
└── train

Pretrained model folder structure:
Torch_MNIST_Example_Model
├── mnist_epoch1.pkl
'''
import argparse
import os

import numpy as np
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.datasets import mnist
from torchvision.transforms import ToTensor

# c2net helper used to stage the dataset and pretrained model into the container
from c2net.context import prepare
from model import Model

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
# NOTE(review): --epoch_size is not used by this evaluation script; it is kept
# so existing launch commands that pass it keep working.
parser.add_argument('--epoch_size', type=int, default=10, help='how much epoch to train')
parser.add_argument('--batch_size', type=int, default=256, help='how much batch_size in epoch')

# Parameter declarations
WORKERS = 0  # DataLoader worker processes (0 = load in the main process)
OUTPUT_DIR = '/tmp/output'  # directory that receives result.txt
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Model().to(device)
# NOTE(review): the optimizer is never used during evaluation; it is kept only
# so the module-level name remains available to any importer.
optimizer = SGD(model.parameters(), lr=1e-1)
cost = CrossEntropyLoss()


# Model evaluation
def test(model, test_loader, data_length, output_dir=OUTPUT_DIR):
    """Evaluate *model* on *test_loader*; print accuracy and write result.txt.

    Args:
        model: network to evaluate; it is switched to eval mode.
        test_loader: iterable yielding ``(x, y)`` batches.
        data_length: total number of test samples, used as the accuracy
            denominator.
        output_dir: directory that receives ``result.txt``; created if it
            does not exist. Defaults to ``/tmp/output``.

    Raises:
        ValueError: if ``data_length`` is not positive or the loader yields
            no batches (previously an UnboundLocalError/ZeroDivisionError).
    """
    if data_length <= 0:
        raise ValueError('data_length must be positive, got {}'.format(data_length))

    model.eval()
    test_loss = 0.0
    correct = 0
    num_batches = 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            y_hat = model(x)
            test_loss += cost(y_hat, y).item()
            # Index of the highest score along dim 1 is the predicted class.
            pred = y_hat.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError('test_loader yielded no batches; check the dataset path')
    # Report the mean of the per-batch losses.
    test_loss /= num_batches

    print('accuracy: {:.2f}'.format(correct / data_length))

    # Write the result into the output folder, creating it if needed.
    os.makedirs(output_dir, exist_ok=True)
    file_path = os.path.join(output_dir, 'result.txt')
    with open(file_path, 'w') as file:
        file.write('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, data_length, 100. * correct / data_length))


def main():
    """Stage data/model via c2net, load the checkpoint and run ``test``."""
    args, _unknown = parser.parse_known_args()

    # Initialize: import the dataset and the pretrained model into the container.
    c2net_context = prepare()
    # Dataset path
    mnist_dataset_path = os.path.join(c2net_context.dataset_path, "MnistDataset_torch")
    # Pretrained model path
    pretrain_model_path = os.path.join(c2net_context.pretrain_model_path,
                                       "Torch_MNIST_Example_Model")

    # log output
    print('cuda is available:{}'.format(torch.cuda.is_available()))

    test_dataset = mnist.MNIST(root=os.path.join(mnist_dataset_path, "test"),
                               train=False, transform=ToTensor(), download=False)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             num_workers=WORKERS)

    # map_location lets a checkpoint saved on GPU load on a CPU-only host,
    # matching the CPU fallback chosen for `device`.
    checkpoint = torch.load(os.path.join(pretrain_model_path, "mnist_epoch1.pkl"),
                            map_location=device)
    # Reuse the module-level model instead of building a second copy.
    model.load_state_dict(checkpoint['model'])
    test(model, test_loader, len(test_dataset))


if __name__ == '__main__':
    main()