
train.py 5.2 kB

from torchvision.transforms import transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import argparse
import time
import shutil
from c2net.context import prepare, upload_output

# Parse only the known arguments so that extra hyperparameters injected by the
# platform do not raise "unrecognized argument" errors.
parser = argparse.ArgumentParser(description='Ignore errors for hyperparameters that do not exist')
# Custom arguments
parser.add_argument('--output')
parser.add_argument('--epoch', type=int, default=1)
parser.add_argument('--card', type=str, default='cuda:0')
args, unknown = parser.parse_known_args()

# Initialize: import the dataset and pretrained model into the container
c2net_context = prepare()
# codePath = c2net_context.code_path
# test = codePath + '/pytorch-cnn-cifar10-dcu' + '/test.py'
# Dataset path
cifar_10_python_path = c2net_context.dataset_path + "/" + "jo-161M"
# Output results must be saved to this directory
outputPath = c2net_context.output_path


class Network(nn.Module):
    """Small CNN for 32x32 CIFAR-10 images: two conv blocks, one max-pool, one linear classifier."""

    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(12)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(12)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(24)
        self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(24)
        # Spatial size: 32 -> 30 -> 28 -> (pool) 14 -> 12 -> 10, so the flattened features are 24 * 10 * 10
        self.fc1 = nn.Linear(24 * 10 * 10, 10)

    def forward(self, input):
        output = F.relu(self.bn1(self.conv1(input)))
        output = F.relu(self.bn2(self.conv2(output)))
        output = self.pool(output)
        output = F.relu(self.bn4(self.conv4(output)))
        output = F.relu(self.bn5(self.conv5(output)))
        output = output.view(-1, 24 * 10 * 10)
        output = self.fc1(output)
        return output


def saveModel():
    # Save the trained weights and zip the output directory for download
    path = outputPath + '/' + 'test.pth'
    torch.save(model.state_dict(), path)
    zipfileName = outputPath + '/' + 'test_database'
    save_zipfile(zipfileName, outputPath)


def testAccuracy(card):
    model.eval()
    accuracy = 0.0
    total = 0.0
    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    device = torch.device(card)
    model.to(device)
    test_dataloader = Get_dataloader(False)
    with torch.no_grad():
        for images, labels in test_dataloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            accuracy += (predicted == labels).sum().item()
    accuracy = 100 * accuracy / total
    return accuracy


def train(num_epochs, card):
    # best_accuracy = 0.0
    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    device = torch.device(card)
    model.to(device)
    train_dataloader = Get_dataloader(True)
    for epoch in range(num_epochs):
        # Switch back to training mode (testAccuracy leaves the model in eval mode)
        model.train()
        running_loss = 0.0
        for i, (images, labels) in enumerate(train_dataloader, 0):
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % 1000 == 999:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 1000))
                running_loss = 0.0
        accuracy = testAccuracy(card)
        print('For epoch', epoch + 1, 'the test accuracy over the whole test set is %d %%' % accuracy)
        # saveModel()


def Get_dataloader(train):
    transform_fn = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = CIFAR10(root=DATA_ROOT, train=train, transform=transform_fn)
    data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    return data_loader


def save_zipfile(filename, dest):
    shutil.make_archive(filename, 'zip', dest)


DATA_ROOT = cifar_10_python_path
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
batch_size = 10
number_of_labels = 10
model = Network()
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

if __name__ == "__main__":
    accuracy = testAccuracy(args.card)
    print('before training, accuracy for test data is: ', accuracy)
    start = time.perf_counter()
    train(args.epoch, args.card)
    end = time.perf_counter()
    print(f"training completed in {end - start:0.4f} seconds")
    saveModel()
    # Upload results back to OpenI once they have been written; only training tasks can upload
    upload_output()
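
When the script runs as an OpenI training task, the c2net runtime supplies the dataset and output directories, so only the custom flags need to be set. A typical launch command might look like the following (the epoch count is illustrative; any extra platform-injected flags are tolerated because the script uses parse_known_args()):

    python train.py --epoch 5 --card cuda:0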
