
train_gcu.py 6.2 kB

#!/usr/bin/python
#coding=utf-8
'''
If there are Chinese comments in the code, please add at the beginning:
#!/usr/bin/python
#coding=utf-8
The dataset used in this example is MnistDataset_torch.zip.
The dataset structure is:
MnistDataset_torch.zip
├── test
│   ├── MNIST/processed/test.pt
│   └── MNIST/processed/training.pt
│   ├── MNIST/raw/train-images-idx3-ubyte
│   └── MNIST/raw/train-labels-idx1-ubyte
│   ├── MNIST/raw/t10k-images-idx3-ubyte
│   └── MNIST/raw/t10k-labels-idx1-ubyte
├── train
│   ├── MNIST/processed/test.pt
│   └── MNIST/processed/training.pt
│   ├── MNIST/raw/train-images-idx3-ubyte
│   └── MNIST/raw/train-labels-idx1-ubyte
│   ├── MNIST/raw/t10k-images-idx3-ubyte
│   └── MNIST/raw/t10k-labels-idx1-ubyte
The pretrained model file used in this example is: mnist_epoch1_0.86.pkl
'''
import os
print("begin:")
os.system("pip install {}".format(os.getenv("OPENI_SDK_PATH")))
import torch
from model import Model
import numpy as np
from torchvision.datasets import mnist
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import argparse
from openi.context import prepare, upload_openi
import importlib.util


def is_torch_dtu_available():
    if importlib.util.find_spec("torch_dtu") is None:
        return False
    if importlib.util.find_spec("torch_dtu.core") is None:
        return False
    return importlib.util.find_spec("torch_dtu.core.dtu_model") is not None
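
# Note: torch_dtu is presumably the PyTorch plugin for the DTU/GCU accelerator this
# script targets (hence the file name train_gcu.py). The check above only verifies
# that the package and its submodules can be imported; it does not confirm that an
# accelerator device is actually present.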


# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--epoch_size', type=int, default=1, help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=256, help='batch size used in each epoch')
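
# Example invocation (flag values are only illustrative). parse_known_args() below
# tolerates any extra hyperparameters the platform may inject on the command line:
#   python train_gcu.py --epoch_size 1 --batch_size 256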

if __name__ == '__main__':
    # Parse arguments and ignore errors from unknown hyperparameters
    args, unknown = parser.parse_known_args()
    # Initialize: import the dataset and the pretrained model into the container
    openi_context = prepare()
    # Get the dataset path, pretrained model path and output path
    dataset_path = openi_context.dataset_path
    pretrain_model_path = openi_context.pretrain_model_path
    output_path = openi_context.output_path
    dataset_path_A = dataset_path + "/MnistDataset"
    pretrain_model_path_A = pretrain_model_path + "/MNIST_PytorchExample_GPU_test34_model_7f9j"
    print("dataset_path:")
    print(os.listdir(dataset_path))
    print("pretrain_model_path:")
    print(os.listdir(pretrain_model_path))
    print("output_path:")
    print(os.listdir(output_path))
    # load DPU envs-xx.sh
    DTU_FLAG = True
    if is_torch_dtu_available():
        import torch_dtu
        import torch_dtu.distributed as dist
        import torch_dtu.core.dtu_model as dm
        from torch_dtu.nn.parallel import DistributedDataParallel as torchDDP
        print('dtu is available: True')
        device = dm.dtu_device()
        DTU_FLAG = True
    else:
        print('dtu is available: False')
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        DTU_FLAG = False
    # Parameter declarations: a single model and optimizer, the data loaders and the loss
    model = Model().to(device)
    optimizer = SGD(model.parameters(), lr=1e-1)
    # log output
    batch_size = args.batch_size
    train_dataset = mnist.MNIST(root=dataset_path_A + "/train", train=True, transform=ToTensor(), download=False)
    test_dataset = mnist.MNIST(root=dataset_path_A + "/test", train=False, transform=ToTensor(), download=False)
    train_loader = DataLoader(train_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    cost = CrossEntropyLoss()
    epochs = args.epoch_size
    print('epoch_size is:{}'.format(epochs))
    # If a saved model exists, load it and continue training from it
    if os.path.exists(pretrain_model_path_A + "/mnist_epoch1_0.70.pkl"):
        checkpoint = torch.load(pretrain_model_path_A + "/mnist_epoch1_0.70.pkl")
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('Loaded epoch {} weights successfully!'.format(start_epoch))
    else:
        start_epoch = 0
        print('No saved model found, training from scratch!')
    for _epoch in range(start_epoch, epochs):
        print('the {} epoch_size begin'.format(_epoch + 1))
        model.train()
        for idx, (train_x, train_label) in enumerate(train_loader):
            train_x = train_x.to(device)
            train_label = train_label.to(device)
            optimizer.zero_grad()
            predict_y = model(train_x.float())
            loss = cost(predict_y, train_label.long())
            if idx % 10 == 0:
                print('idx: {}, loss: {}'.format(idx, loss.sum().item()))
            loss.backward()
            if DTU_FLAG:
                dm.optimizer_step(optimizer, barrier=True)
            else:
                optimizer.step()
        correct = 0
        _sum = 0
        model.eval()
        for idx, (test_x, test_label) in enumerate(test_loader):
            predict_y = model(test_x.to(device).float()).detach()
            predict_ys = np.argmax(predict_y.cpu(), axis=-1)
            _ = predict_ys == test_label
            correct += np.sum(_.numpy(), axis=-1)
            _sum += _.shape[0]
        print('accuracy: {:.2f}'.format(correct / _sum))
        # The model output location is placed under /tmp/output
        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': _epoch + 1}
        torch.save(state, '/tmp/output/mnist_epoch{}_{:.2f}.pkl'.format(_epoch + 1, correct / _sum))
    print('test:')
    print(os.listdir("/tmp/output"))
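
For reference, a minimal sketch of how a checkpoint written by this script could be reloaded for evaluation. It assumes the same Model class from model.py and a hypothetical checkpoint path, and relies only on the dictionary keys ('model', 'optimizer', 'epoch') that the script itself saves:

    import torch
    from model import Model

    # Hypothetical path; substitute the file actually produced under /tmp/output.
    ckpt_path = '/tmp/output/mnist_epoch1_0.86.pkl'

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = Model().to(device)

    # The training script saves a dict with 'model', 'optimizer' and 'epoch' keys.
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    print('restored weights from epoch {}'.format(checkpoint['epoch']))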
