"""Train and evaluate a small CNN on CIFAR-10 inside an OpenI / c2net job.

The dataset location and output directory come from ``c2net.context.prepare()``.
After training, the model weights (``test.pth``) and a zip of the output
directory (``test_database.zip``) are written to the output directory and then
uploaded with ``upload_output()``.
"""
import argparse
import os
import shutil
import tempfile
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable  # noqa: F401  (no longer used; kept for importers)
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms

from c2net.context import prepare, upload_output

# The platform may pass hyper-parameters this script does not declare.
# parse_known_args() collects them in `unknown` instead of exiting with an
# "unrecognized arguments" error (which a prior parse_args() call would do).
parser = argparse.ArgumentParser(description='忽略超参数不存在的报错问题')
# Custom arguments.
parser.add_argument("--output")
parser.add_argument('--epoch', type=int, default=1)
parser.add_argument('--card', type=str, default='cuda:0')
args, unknown = parser.parse_known_args()

# Initialise the job: make the dataset and pretrained models available
# inside the container.
c2net_context = prepare()
# Dataset path.
cifar_10_python_path = c2net_context.dataset_path + "/" + "jo-161M"
# All output artifacts must be saved in this directory.
outputPath = c2net_context.output_path


class Network(nn.Module):
    """Four conv+BatchNorm+ReLU layers with one max-pool, then a linear classifier.

    For 32x32 inputs the spatial size goes 32 -> 30 -> 28 -> (pool) 14 -> 12 -> 10,
    which is why the classifier expects 24 * 10 * 10 features.
    """

    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(12)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(12)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(24)
        self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(24)
        self.fc1 = nn.Linear(24 * 10 * 10, 10)

    def forward(self, input):
        """Return unnormalised class scores (logits), one row per input image."""
        output = F.relu(self.bn1(self.conv1(input)))
        output = F.relu(self.bn2(self.conv2(output)))
        output = self.pool(output)
        output = F.relu(self.bn4(self.conv4(output)))
        output = F.relu(self.bn5(self.conv5(output)))
        output = output.view(-1, 24 * 10 * 10)
        output = self.fc1(output)
        return output


def saveModel():
    """Save the global model's state_dict and a zip of the output directory.

    Writes ``<outputPath>/test.pth`` and ``<outputPath>/test_database.zip``.
    """
    path = os.path.join(outputPath, 'test.pth')
    torch.save(model.state_dict(), path)
    zipfileName = os.path.join(outputPath, 'test_database')
    save_zipfile(zipfileName, outputPath)


def testAccuracy(card):
    """Return the global model's top-1 accuracy, in percent, on the CIFAR-10 test split.

    Args:
        card: torch device string to evaluate on, e.g. ``'cuda:0'`` or ``'cpu'``.

    The model is left in eval mode; train() switches it back to training mode
    at the start of every epoch.
    """
    model.eval()
    device = torch.device(card)
    model.to(device)
    correct = 0.0
    total = 0.0
    with torch.no_grad():
        for images, labels in Get_dataloader(False):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            # Index of the highest logit = predicted class.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


def train(num_epochs, card):
    """Train the global model for `num_epochs` epochs on the CIFAR-10 training split.

    After every epoch, test-set accuracy is printed.

    Args:
        num_epochs: number of full passes over the training set.
        card: torch device string to train on, e.g. ``'cuda:0'``.
    """
    device = torch.device(card)
    model.to(device)
    train_dataloader = Get_dataloader(True)
    for epoch in range(num_epochs):
        # testAccuracy() (run before training and after each epoch) puts the
        # model in eval mode. Re-enable training mode so BatchNorm normalises
        # with batch statistics and updates its running estimates.
        model.train()
        running_loss = 0.0
        for i, (images, labels) in enumerate(train_dataloader):
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % 1000 == 999:
                # Mean loss over the last 1000 mini-batches.
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 1000))
                running_loss = 0.0
        accuracy = testAccuracy(card)
        print('For epoch', epoch + 1, 'the test accuracy over the whole test set is %d %%' % (accuracy))


def Get_dataloader(train):
    """Build a DataLoader over the CIFAR-10 train (``train=True``) or test split.

    Images are converted to tensors and each channel is mapped from [0, 1]
    to [-1, 1] via mean 0.5 / std 0.5.
    """
    transform_fn = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = CIFAR10(root=DATA_ROOT, train=train, transform=transform_fn)
    # Shuffle only for training; evaluation order does not affect accuracy.
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=train, num_workers=0)


def save_zipfile(filename, dest):
    """Zip the contents of directory `dest` into ``filename + '.zip'``.

    `filename` may point inside `dest` (saveModel does exactly that). The archive
    is therefore built in a temporary directory and moved into place afterwards.
    Otherwise shutil.make_archive would walk into its own half-written zip
    and try to add it to itself.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_base = os.path.join(tmp_dir, os.path.basename(filename))
        archive = shutil.make_archive(tmp_base, 'zip', dest)
        shutil.move(archive, filename + '.zip')


DATA_ROOT = cifar_10_python_path
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
batch_size = 10
number_of_labels = 10
model = Network()
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

if __name__ == "__main__":
    accuracy = testAccuracy(args.card)
    print('before training, accuracy for test data is: ', accuracy)
    start = time.perf_counter()
    train(args.epoch, args.card)
    end = time.perf_counter()
    print(f"training completed in {end - start:0.4f} seconds")
    saveModel()
    # Upload results to OpenI (only training jobs can upload). This must run
    # after saveModel() so the weights and archive are already in outputPath.
    upload_output()