You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train_gpu.py 4.7 kB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. #!/usr/bin/python
  2. #coding=utf-8
  3. '''
  4. If there are Chinese comments in the code,please add at the beginning:
  5. #!/usr/bin/python
  6. #coding=utf-8
  7. 1,The dataset structure of the single-dataset in this example
  8. MnistDataset_torch.zip
  9. ├── test
  10. └── train
  11. '''
  12. import os
  13. os.system("pip cache purge")
  14. os.system("pip install {}".format(os.getenv("OPENI_SDK_PATH")))
  15. from model import Model
  16. import numpy as np
  17. import torch
  18. from torchvision.datasets import mnist
  19. from torch.nn import CrossEntropyLoss
  20. from torch.optim import SGD
  21. from torch.utils.data import DataLoader
  22. from torchvision.transforms import ToTensor
  23. import argparse
  24. #导入openi包
  25. from openi.context import prepare, upload_openi
  26. # Training settings
  27. parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  28. parser.add_argument('--epoch_size', type=int, default=10, help='how much epoch to train')
  29. parser.add_argument('--batch_size', type=int, default=256, help='how much batch_size in epoch')
  30. # 参数声明
  31. WORKERS = 0 # dataloder线程数
  32. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  33. model = Model().to(device)
  34. optimizer = SGD(model.parameters(), lr=1e-1)
  35. cost = CrossEntropyLoss()
  36. # 模型训练
  37. def train(model, train_loader, epoch):
  38. model.train()
  39. train_loss = 0
  40. for i, data in enumerate(train_loader, 0):
  41. x, y = data
  42. x = x.to(device)
  43. y = y.to(device)
  44. optimizer.zero_grad()
  45. y_hat = model(x)
  46. loss = cost(y_hat, y)
  47. loss.backward()
  48. optimizer.step()
  49. train_loss += loss
  50. loss_mean = train_loss / (i+1)
  51. print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean.item()))
  52. # 模型测试
  53. def test(model, test_loader, test_data):
  54. model.eval()
  55. test_loss = 0
  56. correct = 0
  57. with torch.no_grad():
  58. for i, data in enumerate(test_loader, 0):
  59. x, y = data
  60. x = x.to(device)
  61. y = y.to(device)
  62. optimizer.zero_grad()
  63. y_hat = model(x)
  64. test_loss += cost(y_hat, y).item()
  65. pred = y_hat.max(1, keepdim=True)[1]
  66. correct += pred.eq(y.view_as(pred)).sum().item()
  67. test_loss /= (i+1)
  68. print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
  69. test_loss, correct, len(test_data), 100. * correct / len(test_data)))
  70. if __name__ == '__main__':
  71. args, unknown = parser.parse_known_args()
  72. #初始化导入数据集和预训练模型到容器内
  73. openi_context = prepare()
  74. #获取数据集路径,预训练模型路径,输出路径
  75. dataset_path = openi_context.dataset_path
  76. pretrain_model_path = openi_context.pretrain_model_path
  77. output_path = openi_context.output_path
  78. print("dataset_path:")
  79. print(os.listdir(dataset_path))
  80. os.listdir(dataset_path)
  81. print("pretrain_model_path:")
  82. print(os.listdir(pretrain_model_path))
  83. os.listdir(pretrain_model_path)
  84. print("output_path:")
  85. print(os.listdir(output_path))
  86. os.listdir(output_path)
  87. #log output
  88. print('cuda is available:{}'.format(torch.cuda.is_available()))
  89. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  90. batch_size = args.batch_size
  91. epochs = args.epoch_size
  92. train_dataset = mnist.MNIST(root=os.path.join(dataset_path + "/MnistDataset_torch", "train"), train=True, transform=ToTensor(),download=False)
  93. test_dataset = mnist.MNIST(root=os.path.join(dataset_path+ "/MnistDataset_torch", "test"), train=False, transform=ToTensor(),download=False)
  94. train_loader = DataLoader(train_dataset, batch_size=batch_size)
  95. test_loader = DataLoader(test_dataset, batch_size=batch_size)
  96. #如果有保存的模型,则加载模型,并在其基础上继续训练
  97. if os.path.exists(os.path.join(pretrain_model_path + "/MNIST_PytorchExample_GPU_test34_model_7f9j", "mnist_epoch1_0.70.pkl")):
  98. checkpoint = torch.load(os.path.join(pretrain_model_path + "/MNIST_PytorchExample_GPU_test34_model_7f9j", "mnist_epoch1_0.70.pkl"))
  99. model.load_state_dict(checkpoint['model'])
  100. optimizer.load_state_dict(checkpoint['optimizer'])
  101. start_epoch = checkpoint['epoch']
  102. print('加载 epoch {} 权重成功!'.format(start_epoch))
  103. else:
  104. start_epoch = 0
  105. print('无保存模型,将从头开始训练!')
  106. for epoch in range(start_epoch+1, epochs):
  107. train(model, train_loader, epoch)
  108. test(model, test_loader, test_dataset)
  109. # 保存模型
  110. state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
  111. torch.save(state, '{}/mnist_epoch{}.pkl'.format(output_path, epoch))

No Description