You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

nn_regression.py 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. from PIL import Image
  2. from tqdm import tqdm
  3. import argparse
  4. import torch
  5. import torch.nn as nn
  6. import torchvision
  7. from torch.utils.data import DataLoader
  8. from torchvision import transforms
  9. from sklearn.metrics import mean_squared_error
  10. from sedna.common.config import Context, BaseConfig
  11. from sedna.datasources import TxtDataParse
  12. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  13. def preprocess(train_data):
  14. x_data, y_data = train_data.x, train_data.y
  15. # preprocess label
  16. y_data = list(map(lambda y: [float(y)], y_data))
  17. y_data = torch.tensor(y_data)
  18. # preprocess images
  19. transformed_images = []
  20. for img_url in x_data:
  21. img = Image.open(img_url).convert('RGB')
  22. transformation = transforms.Compose([
  23. transforms.ToTensor(),
  24. transforms.Normalize(mean=(0.485, 0.456, 0.406),
  25. std=(0.229, 0.224, 0.225))
  26. ])
  27. img = transformation(img).unsqueeze(0).to(device)
  28. transformed_images.append(img[0])
  29. return transformed_images, y_data
  30. def set_args(**kwargs):
  31. parser = argparse.ArgumentParser(description="NeuralNetworkRegression Training")
  32. parser.add_argument("--learning-rate", type=float, default=0.001)
  33. parser.add_argument("--batch-size", type=int, default=20)
  34. parser.add_argument("--test-batch-size", type=int, default=1)
  35. parser.add_argument("--num-epoch", type=int, default=200)
  36. parser.add_argument("--hidden-size", type=int, default=32)
  37. parser.add_argument("--cuda", action="store_true", default=torch.cuda.is_available())
  38. args = parser.parse_args()
  39. return args
  40. class NNRregressionNet(nn.Module):
  41. def __init__(self, backbone, hidden_size):
  42. super(NNRregressionNet, self).__init__()
  43. self.backbone = backbone
  44. self.fc1 = nn.Linear(1000, hidden_size)
  45. self.relu1 = nn.ReLU()
  46. self.fc2 = nn.Linear(hidden_size, 1)
  47. self.relu2 = nn.ReLU()
  48. def forward(self, x):
  49. out = self.backbone(x)
  50. out = self.fc1(out)
  51. out = self.relu1(out)
  52. out = self.fc2(out)
  53. return out
  54. class NNRegression():
  55. def __init__(self, args):
  56. backbone = torchvision.models.resnet18(pretrained=True).to(device)
  57. self.args = args
  58. self.model = NNRregressionNet(backbone, args.hidden_size).to(device)
  59. self.criterion = nn.CrossEntropyLoss()
  60. self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
  61. def train(self, train_data, valid_data=None):
  62. x_data, y_data = preprocess(train_data)
  63. train_data = list(zip(x_data, y_data))
  64. train_data_loader = DataLoader(train_data, batch_size=self.args.batch_size, shuffle=True)
  65. for epoch in range(self.args.num_epoch):
  66. self.model.train()
  67. for i, (images, labels) in enumerate(tqdm(train_data_loader)):
  68. if self.args.cuda:
  69. images = images.cuda()
  70. labels = labels.cuda()
  71. # forward pass
  72. # 完整的模型为self.model
  73. # 假设forward只涉及第n+1层以后
  74. new_model = nn.Sequential(*list(self.model.children())[n+1:]) # 将模型切片,不确定这样切片对不对,可以尝试其他
  75. outputs = self.model(images)
  76. loss = self.criterion(outputs, labels)
  77. # backward and optimize
  78. self.optimizer.zero_grad()
  79. loss.backward()
  80. self.optimizer.step()
  81. def predict(self, data):
  82. self.model.eval()
  83. prediction = []
  84. test_data_loader = DataLoader(data, batch_size=self.args.test_batch_size, shuffle=False)
  85. for image in tqdm(test_data_loader):
  86. if self.args.cuda:
  87. image = image.cuda()
  88. with torch.no_grad():
  89. output = self.model(image)
  90. output = output.data.cpu().numpy()[0]
  91. prediction.append(output)
  92. return prediction
  93. def eval(self, data, metric=None):
  94. x_data, y_data = preprocess(data)
  95. y_data = y_data.data.cpu().numpy()
  96. prediction = self.predict(x_data)
  97. score = metric(y_data, prediction)
  98. return score
  99. def save(self, model_name):
  100. torch.save(self.model.state_dict(), model_name)
  101. return model_name
  102. def load(self, model_path):
  103. self.model.load_state_dict(torch.load(model_path))
  104. def train():
  105. train_dataset_url = "./data_txt/regression_train.txt"
  106. train_data = TxtDataParse(data_type="train")
  107. train_data.parse(train_dataset_url, use_raw=False)
  108. args = set_args()
  109. regressor = NNRegression(args)
  110. regressor.train(train_data)
  111. regressor.save("./models/nn_regression.pth")
  112. def eval():
  113. test_dataset_url = "./data_txt/regression_test.txt"
  114. test_data = TxtDataParse(data_type="eval")
  115. test_data.parse(test_dataset_url, use_raw=False)
  116. args = set_args()
  117. regressor = NNRegression(args)
  118. regressor.load("./models/nn_regression.pth")
  119. eval_result = regressor.eval(test_data, metric=mean_squared_error)
  120. print("MSE:", eval_result)
  121. if __name__ == '__main__':
  122. eval()