You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

inference.py 3.1 kB

1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. #!/usr/bin/python
  2. #coding=utf-8
  3. '''
  4. If there are Chinese comments in the code, please add the following lines at the beginning:
  5. #!/usr/bin/python
  6. #coding=utf-8
  7. 示例选用的数据集是MnistDataset_torch.zip
  8. 数据集结构是:
  9. MnistDataset_torch.zip
  10. ├── test
  11. └── train
  12. 预训练模型文件夹结构是:
  13. Torch_MNIST_Example_Model
  14. ├── mnist_epoch1.pkl
  15. '''
  16. from model import Model
  17. import numpy as np
  18. import torch
  19. from torchvision.datasets import mnist
  20. from torch.nn import CrossEntropyLoss
  21. from torch.optim import SGD
  22. from torch.utils.data import DataLoader
  23. from torchvision.transforms import ToTensor
  24. import argparse
  25. import os
  26. #导入c2net包
  27. from c2net.context import prepare
  28. # Training settings
  29. parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  30. parser.add_argument('--epoch_size', type=int, default=10, help='how much epoch to train')
  31. parser.add_argument('--batch_size', type=int, default=256, help='how much batch_size in epoch')
  32. # 参数声明
  33. WORKERS = 0
  34. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  35. model = Model().to(device)
  36. optimizer = SGD(model.parameters(), lr=1e-1)
  37. cost = CrossEntropyLoss()
  38. # 模型测试
  39. def test(model, test_loader, data_length):
  40. model.eval()
  41. test_loss = 0
  42. correct = 0
  43. with torch.no_grad():
  44. for i, data in enumerate(test_loader, 0):
  45. x, y = data
  46. x = x.to(device)
  47. y = y.to(device)
  48. y_hat = model(x)
  49. test_loss += cost(y_hat, y).item()
  50. pred = y_hat.max(1, keepdim=True)[1]
  51. correct += pred.eq(y.view_as(pred)).sum().item()
  52. test_loss /= (i+1)
  53. # 结果写入输出文件夹
  54. print('accuracy: {:.2f}'.format(correct / data_length))
  55. filename = 'result.txt'
  56. file_path = os.path.join('/tmp/output', filename)
  57. with open(file_path, 'w') as file:
  58. file.write('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
  59. test_loss, correct, data_length, 100. * correct / data_length))
  60. if __name__ == '__main__':
  61. args, unknown = parser.parse_known_args()
  62. #初始化导入数据集和预训练模型到容器内
  63. c2net_context = prepare()
  64. #获取数据集路径
  65. MnistDataset_torch_path = c2net_context.dataset_path+"/"+"MnistDataset_torch"
  66. #获取预训练模型路径
  67. Torch_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"Torch_MNIST_Example_Model"
  68. #log output
  69. print('cuda is available:{}'.format(torch.cuda.is_available()))
  70. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  71. batch_size = args.batch_size
  72. epochs = args.epoch_size
  73. test_dataset = mnist.MNIST(root=MnistDataset_torch_path + "/test", train=False, transform=ToTensor(),download=False)
  74. test_loader = DataLoader(test_dataset, batch_size=batch_size)
  75. model = Model().to(device)
  76. checkpoint = torch.load(Torch_MNIST_Example_Model_path + "/mnist_epoch1.pkl")
  77. model.load_state_dict(checkpoint['model'])
  78. test(model,test_loader,len(test_dataset))

No Description