You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

lenet5.py 3.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. # coding: utf-8
  2. #================================================================#
  3. # Copyright (C) 2021 Freecss All rights reserved.
  4. #
  5. # File Name :lenet5.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2021/03/03
  9. # Description :
  10. #
  11. #================================================================#
  12. import sys
  13. sys.path.append("..")
  14. import torchvision
  15. import torch
  16. from torch import nn
  17. from torch.nn import functional as F
  18. from torch.autograd import Variable
  19. import torchvision.transforms as transforms
  20. from models.basic_model import BasicModel
  21. import utils.plog as plog
  22. class LeNet5(nn.Module):
  23. def __init__(self):
  24. super().__init__()
  25. self.conv1 = nn.Conv2d(1, 6, 3, padding=1)
  26. self.conv2 = nn.Conv2d(6, 16, 3)
  27. self.conv3 = nn.Conv2d(16, 16, 3)
  28. self.fc1 = nn.Linear(256, 120)
  29. self.fc2 = nn.Linear(120, 84)
  30. self.fc3 = nn.Linear(84, 13)
  31. def forward(self, x):
  32. '''前向传播函数'''
  33. x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
  34. x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
  35. x = F.relu(self.conv3(x))
  36. x = x.view(-1, self.num_flat_features(x))
  37. #print(x.size())
  38. x = F.relu(self.fc1(x))
  39. x = F.relu(self.fc2(x))
  40. x = self.fc3(x)
  41. return x
  42. def num_flat_features(self, x):
  43. #x.size()返回值为(256, 16, 5, 5),size的值为(16, 5, 5),256是batch_size
  44. size = x.size()[1:] #x.size返回的是一个元组,size表示截取元组中第二个开始的数字
  45. num_features = 1
  46. for s in size:
  47. num_features *= s
  48. return num_features
  49. class Params:
  50. imgH = 28
  51. imgW = 28
  52. keep_ratio = True
  53. saveInterval = 10
  54. batchSize = 16
  55. num_workers = 16
  56. def get_data(): #数据预处理
  57. transform = transforms.Compose([transforms.ToTensor(),
  58. transforms.Normalize((0.5), (0.5))])
  59. #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  60. #训练集
  61. train_set = torchvision.datasets.MNIST(root='data/', train=True, transform=transform, download=True)
  62. train_loader = torch.utils.data.DataLoader(train_set, batch_size=1024, shuffle=True, num_workers = 16)
  63. #测试集
  64. test_set = torchvision.datasets.MNIST(root='data/', train=False, transform=transform, download=True)
  65. test_loader = torch.utils.data.DataLoader(test_set, batch_size = 1024, shuffle = False, num_workers = 16)
  66. classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')
  67. return train_loader, test_loader, classes
  68. if __name__ == "__main__":
  69. recorder = plog.ResultRecorder()
  70. cls = LeNet5()
  71. criterion = nn.CrossEntropyLoss(size_average=True)
  72. optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))
  73. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  74. model = BasicModel(cls, criterion, optimizer, None, device, Params(), recorder)
  75. train_loader, test_loader, classes = get_data()
  76. #model.val(test_loader, print_prefix = "before training")
  77. model.fit(train_loader, n_epoch = 100)
  78. model.val(test_loader, print_prefix = "after trained")
  79. res = model.predict(test_loader, print_prefix = "predict")
  80. print(res.argmax(axis=1)[:10])

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.