You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 5.2 kB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
1 year ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. #!/usr/bin/python
  2. #coding=utf-8
  3. '''
  4. If there are Chinese comments in the code, please add at the beginning:
  5. #!/usr/bin/python
  6. #coding=utf-8
  7. The dataset used in this example is MnistDataset_torch.zip
  8. The dataset structure is:
  9. MnistDataset_torch.zip
  10. ├── test
  11. └── train
  12. The pretrained model folder structure is:
  13. Torch_MNIST_Example_Model
  14. ├── mnist_epoch1.pkl
  15. '''
  16. import torch
  17. from model import Model
  18. import numpy as np
  19. from torchvision.datasets import mnist
  20. from torch.nn import CrossEntropyLoss
  21. from torch.optim import SGD
  22. from torch.utils.data import DataLoader
  23. from torchvision.transforms import ToTensor
  24. import argparse
  25. import os
  26. #导入c2net包
  27. from c2net.context import prepare, upload_output
  28. import importlib.util
  29. def is_torch_dtu_available():
  30. if importlib.util.find_spec("torch_dtu") is None:
  31. return False
  32. if importlib.util.find_spec("torch_dtu.core") is None:
  33. return False
  34. return importlib.util.find_spec("torch_dtu.core.dtu_model") is not None
  35. # Training settings
  36. parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  37. parser.add_argument('--epoch_size', type=int, default=1, help='how much epoch to train')
  38. parser.add_argument('--batch_size', type=int, default=256, help='how much batch_size in epoch')
  39. if __name__ == '__main__':
  40. args, unknown = parser.parse_known_args()
  41. #初始化导入数据集和预训练模型到容器内
  42. c2net_context = prepare()
  43. #获取数据集路径
  44. MnistDataset_torch_path = c2net_context.dataset_path+"/"+"MnistDataset_torch"
  45. #获取预训练模型路径
  46. Torch_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"GCU_MNIST_Example_Model"
  47. #获取输出路径
  48. output_path = c2net_context.output_path
  49. # load DPU envs-xx.sh
  50. DTU_FLAG = True
  51. if is_torch_dtu_available():
  52. import torch_dtu
  53. import torch_dtu.distributed as dist
  54. import torch_dtu.core.dtu_model as dm
  55. from torch_dtu.nn.parallel import DistributedDataParallel as torchDDP
  56. print('dtu is available: True')
  57. device = dm.dtu_device()
  58. DTU_FLAG = True
  59. else:
  60. print('dtu is available: False')
  61. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  62. DTU_FLAG = False
  63. # 参数声明
  64. model = Model().to(device)
  65. optimizer = SGD(model.parameters(), lr=1e-1)
  66. args, unknown = parser.parse_known_args()
  67. #log output
  68. batch_size = args.batch_size
  69. train_dataset = mnist.MNIST(root=os.path.join(MnistDataset_torch_path, "train"), train=True, transform=ToTensor(),download=False)
  70. test_dataset = mnist.MNIST(root=os.path.join(MnistDataset_torch_path, "test"), train=False, transform=ToTensor(),download=False)
  71. train_loader = DataLoader(train_dataset, batch_size=batch_size)
  72. test_loader = DataLoader(test_dataset, batch_size=batch_size)
  73. model = Model().to(device)
  74. sgd = SGD(model.parameters(), lr=1e-1)
  75. cost = CrossEntropyLoss()
  76. epochs = args.epoch_size
  77. print('epoch_size is:{}'.format(epochs))
  78. # 如果有保存的模型,则加载模型,并在其基础上继续训练
  79. if os.path.exists(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1_0.81.pkl")):
  80. checkpoint = torch.load(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1_0.81.pkl"))
  81. model.load_state_dict(checkpoint['model'])
  82. optimizer.load_state_dict(checkpoint['optimizer'])
  83. start_epoch = checkpoint['epoch']
  84. print('加载 epoch {} 权重成功!'.format(start_epoch))
  85. else:
  86. start_epoch = 0
  87. print('无保存模型,将从头开始训练!')
  88. for _epoch in range(start_epoch, epochs):
  89. print('the {} epoch_size begin'.format(_epoch + 1))
  90. model.train()
  91. for idx, (train_x, train_label) in enumerate(train_loader):
  92. train_x = train_x.to(device)
  93. train_label = train_label.to(device)
  94. label_np = np.zeros((train_label.shape[0], 10))
  95. sgd.zero_grad()
  96. predict_y = model(train_x.float())
  97. loss = cost(predict_y, train_label.long())
  98. if idx % 10 == 0:
  99. print('idx: {}, loss: {}'.format(idx, loss.sum().item()))
  100. loss.backward()
  101. if DTU_FLAG:
  102. dm.optimizer_step(sgd, barrier=True)
  103. else:
  104. sgd.step()
  105. correct = 0
  106. _sum = 0
  107. model.eval()
  108. for idx, (test_x, test_label) in enumerate(test_loader):
  109. test_x = test_x
  110. test_label = test_label
  111. predict_y = model(test_x.to(device).float()).detach()
  112. predict_ys = np.argmax(predict_y.cpu(), axis=-1)
  113. label_np = test_label.numpy()
  114. _ = predict_ys == test_label
  115. correct += np.sum(_.numpy(), axis=-1)
  116. _sum += _.shape[0]
  117. print('accuracy: {:.2f}'.format(correct / _sum))
  118. state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':_epoch+1}
  119. torch.save(state, '{}/mnist_epoch{}_{:.2f}.pkl'.format(output_path, _epoch+1, correct / _sum))
  120. print(os.listdir('{}'.format(output_path)))

No Description