You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

parallel_train.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. #!/usr/bin/python
  2. #coding=utf-8
  3. '''
  4. If there are Chinese comments in the code,please add at the beginning:
  5. #!/usr/bin/python
  6. #coding=utf-8
  7. 数据集结构是:
  8. MnistDataset_torch.zip
  9. ├── test
  10. └── train
  11. 预训练模型文件夹结构是:
  12. Torch_MNIST_Example_Model
  13. ├── mnist_epoch1.pkl
  14. '''
  15. from model import Model
  16. import numpy as np
  17. import torch
  18. import torch.nn as nn
  19. from torchvision.datasets import mnist
  20. from torch.nn import CrossEntropyLoss
  21. from torch.optim import SGD
  22. from torch.utils.data import DataLoader
  23. from torchvision.transforms import ToTensor
  24. import argparse
  25. import os
  26. os.system("pip install c2net")
  27. #导入c2net包
  28. from c2net.context import prepare
  29. # Training settings
  30. parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  31. parser.add_argument('--epoch_size', type=int, default=10, help='how much epoch to train')
  32. parser.add_argument('--batch_size', type=int, default=256, help='how much batch_size in epoch')
  33. # 参数声明
  34. WORKERS = 0 # dataloder线程数
  35. # 检查可用GPU数量
  36. if torch.cuda.device_count() < 2:
  37. raise RuntimeError("需要至少2块GPU,但当前只有 {} 块".format(torch.cuda.device_count()))
  38. else:
  39. print('当前有 {} 块GPU'.format(torch.cuda.device_count()))
  40. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  41. # 2. 初始化模型并并行化
  42. model = Model().to(device)
  43. print('开始进行并行训练!')
  44. model = nn.DataParallel(model, device_ids=[0, 1]) # 使用GPU 0和1
  45. optimizer = SGD(model.parameters(), lr=1e-1)
  46. cost = CrossEntropyLoss()
  47. # 模型训练
  48. def train(model, train_loader, epoch):
  49. model.train()
  50. train_loss = 0
  51. for i, data in enumerate(train_loader, 0):
  52. x, y = data
  53. x = x.to(device)
  54. y = y.to(device)
  55. optimizer.zero_grad()
  56. y_hat = model(x)
  57. loss = cost(y_hat, y)
  58. loss.backward()
  59. optimizer.step()
  60. train_loss += loss
  61. loss_mean = train_loss / (i+1)
  62. print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean.item()))
  63. # 模型测试
  64. def test(model, test_loader, test_data):
  65. model.eval()
  66. test_loss = 0
  67. correct = 0
  68. with torch.no_grad():
  69. for i, data in enumerate(test_loader, 0):
  70. x, y = data
  71. x = x.to(device)
  72. y = y.to(device)
  73. optimizer.zero_grad()
  74. y_hat = model(x)
  75. test_loss += cost(y_hat, y).item()
  76. pred = y_hat.max(1, keepdim=True)[1]
  77. correct += pred.eq(y.view_as(pred)).sum().item()
  78. test_loss /= (i+1)
  79. print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
  80. test_loss, correct, len(test_data), 100. * correct / len(test_data)))
  81. if __name__ == '__main__':
  82. args, unknown = parser.parse_known_args()
  83. #初始化导入数据集和预训练模型到容器内
  84. c2net_context = prepare()
  85. #获取数据集路径
  86. MnistDataset_torch_path = c2net_context.dataset_path+"/"+"MnistDataset_torch"
  87. #获取预训练模型路径
  88. Torch_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"Torch_MNIST_Example_Model"
  89. #log output
  90. print('cuda is available:{}'.format(torch.cuda.is_available()))
  91. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  92. batch_size = args.batch_size
  93. epochs = args.epoch_size
  94. train_dataset = mnist.MNIST(root=os.path.join(MnistDataset_torch_path, "train"), train=True, transform=ToTensor(),download=False)
  95. # test_dataset = mnist.MNIST(root=os.path.join(MnistDataset_torch_path, "test"), train=False, transform=ToTensor(),download=False)
  96. train_loader = DataLoader(train_dataset, batch_size=batch_size)
  97. # test_loader = DataLoader(test_dataset, batch_size=batch_size)
  98. #如果有保存的模型,则加载模型,并在其基础上继续训练
  99. if os.path.exists(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1.pkl")):
  100. checkpoint = torch.load(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1.pkl"))
  101. model.load_state_dict(checkpoint['model'])
  102. optimizer.load_state_dict(checkpoint['optimizer'])
  103. start_epoch = checkpoint['epoch']
  104. print('加载 epoch {} 权重成功!'.format(start_epoch))
  105. else:
  106. start_epoch = 0
  107. print('无保存模型,将从头开始训练!')
  108. for epoch in range(start_epoch+1, epochs+1):
  109. train(model, train_loader, epoch)
  110. # test(model, test_loader, test_dataset)
  111. # 将模型保存到c2net_context.output_path
  112. state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
  113. torch.save(state, '{}/mnist_epoch{}.pkl'.format(c2net_context.output_path, epoch))

No Description