You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

basic_model.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. # coding: utf-8
  2. #================================================================#
  3. # Copyright (C) 2020 Freecss All rights reserved.
  4. #
  5. # File Name :basic_model.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2020/11/21
  9. # Description :
  10. #
  11. #================================================================#
  12. import sys
  13. sys.path.append("..")
  14. import torch
  15. from torch.autograd import Variable
  16. from torch.utils.data import Dataset
  17. import torchvision
  18. import os
  19. from multiprocessing import Pool
  20. import random
  21. import torch
  22. from torch.utils.data import Dataset
  23. from torch.utils.data import sampler
  24. import torchvision.transforms as transforms
  25. import six
  26. import sys
  27. from PIL import Image
  28. import numpy as np
  29. import collections
  30. class resizeNormalize(object):
  31. def __init__(self, size, interpolation=Image.BILINEAR):
  32. self.size = size
  33. self.interpolation = interpolation
  34. self.toTensor = transforms.ToTensor()
  35. self.transform = transforms.Compose([
  36. #transforms.ToPILImage(),
  37. #transforms.RandomHorizontalFlip(),
  38. #transforms.RandomVerticalFlip(),
  39. #transforms.RandomRotation(30),
  40. #transforms.RandomAffine(30),
  41. transforms.ToTensor(),
  42. ])
  43. def __call__(self, img):
  44. #img = img.resize(self.size, self.interpolation)
  45. #img = self.toTensor(img)
  46. img = self.transform(img)
  47. img.sub_(0.5).div_(0.5)
  48. return img
  49. class XYDataset(Dataset):
  50. def __init__(self, X, Y, transform=None, target_transform=None):
  51. self.X = X
  52. self.Y = Y
  53. self.n_sample = len(X)
  54. self.transform = transform
  55. self.target_transform = target_transform
  56. def __len__(self):
  57. return len(self.X)
  58. def __getitem__(self, index):
  59. assert index < len(self), 'index range error'
  60. img = self.X[index]
  61. if self.transform is not None:
  62. img = self.transform(img)
  63. label = self.Y[index]
  64. if self.target_transform is not None:
  65. label = self.target_transform(label)
  66. return (img, label, index)
  67. class alignCollate(object):
  68. def __init__(self, imgH=32, imgW=100, keep_ratio=False, min_ratio=1):
  69. self.imgH = imgH
  70. self.imgW = imgW
  71. self.keep_ratio = keep_ratio
  72. self.min_ratio = min_ratio
  73. def __call__(self, batch):
  74. images, labels, img_keys = zip(*batch)
  75. imgH = self.imgH
  76. imgW = self.imgW
  77. if self.keep_ratio:
  78. ratios = []
  79. for image in images:
  80. w, h = image.shape[:2]
  81. ratios.append(w / float(h))
  82. ratios.sort()
  83. max_ratio = ratios[-1]
  84. imgW = int(np.floor(max_ratio * imgH))
  85. imgW = max(imgH * self.min_ratio, imgW) # assure imgH >= imgW
  86. transform = resizeNormalize((imgW, imgH))
  87. images = [transform(image) for image in images]
  88. images = torch.cat([t.unsqueeze(0) for t in images], 0)
  89. labels = torch.LongTensor(labels)
  90. return images, labels, img_keys
  91. class FakeRecorder():
  92. def __init__(self):
  93. pass
  94. def print(self, *x):
  95. pass
  96. from torch.nn import init
  97. from torch import nn
  98. def weigth_init(m):
  99. if isinstance(m, nn.Conv2d):
  100. init.xavier_uniform_(m.weight.data)
  101. init.constant_(m.bias.data,0.1)
  102. elif isinstance(m, nn.BatchNorm2d):
  103. m.weight.data.fill_(1)
  104. m.bias.data.zero_()
  105. elif isinstance(m, nn.Linear):
  106. m.weight.data.normal_(0,0.01)
  107. m.bias.data.zero_()
  108. class BasicModel():
  109. def __init__(self,
  110. model,
  111. criterion,
  112. optimizer,
  113. device,
  114. params,
  115. sign_list,
  116. transform = None,
  117. target_transform=None,
  118. collate_fn = None,
  119. pretrained = False,
  120. recorder = None):
  121. self.model = model.to(device)
  122. self.criterion = criterion
  123. self.optimizer = optimizer
  124. self.transform = transform
  125. self.target_transform = target_transform
  126. self.device = device
  127. sign_list = sorted(list(set(sign_list)))
  128. self.mapping = dict(zip(sign_list, list(range(len(sign_list)))))
  129. self.remapping = dict(zip(list(range(len(sign_list))), sign_list))
  130. if recorder is None:
  131. recorder = FakeRecorder()
  132. self.recorder = recorder
  133. if pretrained:
  134. # the paths of model, optimizer should be included in params
  135. self.load(params.load_dir)
  136. else:
  137. self.model.apply(weigth_init)
  138. self.save_interval = params.saveInterval
  139. self.params = params
  140. self.collate_fn = collate_fn
  141. pass
  142. def _fit(self, data_loader, n_epoch, stop_loss):
  143. recorder = self.recorder
  144. recorder.print("model fitting")
  145. min_loss = 999999999
  146. for epoch in range(n_epoch):
  147. loss_value = self.train_epoch(data_loader)
  148. recorder.print(f"{epoch}/{n_epoch} model training loss is {loss_value}")
  149. if loss_value < min_loss:
  150. min_loss = loss_value
  151. if epoch > 0 and self.save_interval is not None and epoch % self.save_interval == 0:
  152. assert hasattr(self.params, 'save_dir')
  153. self.save(self.params.save_dir)
  154. if stop_loss is not None and loss_value < stop_loss:
  155. break
  156. recorder.print("Model fitted, minimal loss is ", min_loss)
  157. return loss_value
  158. def str2ints(self, Y):
  159. return [self.mapping[y] for y in Y]
  160. def fit(self, data_loader = None,
  161. X = None,
  162. y = None):
  163. if data_loader is None:
  164. params = self.params
  165. collate_fn = self.collate_fn
  166. transform = self.transform
  167. target_transform = self.target_transform
  168. Y = self.str2ints(y)
  169. train_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
  170. sampler = None
  171. data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params.batchSize, \
  172. shuffle=True, sampler=sampler, num_workers=int(params.workers), \
  173. collate_fn=collate_fn)
  174. return self._fit(data_loader, params.n_epoch, params.stop_loss)
  175. def train_epoch(self, data_loader):
  176. # loss_avg = mutils.averager()
  177. self.model.train()
  178. loss_value = 0
  179. for i, data in enumerate(data_loader):
  180. X = data[0]
  181. Y = data[1]
  182. loss = self.train_batch(X, Y)
  183. loss_value += loss.item()
  184. return loss_value
  185. def train_batch(self, X, Y):
  186. #cpu_images, cpu_texts, _ = data
  187. model = self.model
  188. criterion = self.criterion
  189. optimizer = self.optimizer
  190. device = self.device
  191. # init training status
  192. # torch.autograd.set_detect_anomaly(True)
  193. # model predict
  194. X = X.to(device)
  195. Y = Y.to(device)
  196. pred_Y = model(X)
  197. # calculate loss
  198. loss = criterion(pred_Y, Y)
  199. # back propagation and optimize
  200. optimizer.zero_grad()
  201. loss.backward()
  202. optimizer.step()
  203. return loss
  204. def _predict(self, data_loader):
  205. model = self.model
  206. device = self.device
  207. model.eval()
  208. with torch.no_grad():
  209. results = []
  210. for i, data in enumerate(data_loader):
  211. X = data[0].to(device)
  212. pred_Y = model(X)
  213. results.append(pred_Y)
  214. return torch.cat(results, axis=0)
  215. def predict(self, data_loader = None, X = None, print_prefix = ""):
  216. if data_loader is None:
  217. params = self.params
  218. collate_fn = self.collate_fn
  219. transform = self.transform
  220. target_transform = self.target_transform
  221. Y = [0] * len(X)
  222. val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
  223. sampler = None
  224. data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=params.batchSize, \
  225. shuffle=False, sampler=sampler, num_workers=int(params.workers), \
  226. collate_fn=collate_fn)
  227. recorder = self.recorder
  228. recorder.print('Start Predict ', print_prefix)
  229. Y = self._predict(data_loader).argmax(axis=1)
  230. return [self.remapping[int(y)] for y in Y]
  231. def predict_proba(self, data_loader = None, X = None, print_prefix = ""):
  232. if data_loader is None:
  233. params = self.params
  234. collate_fn = self.collate_fn
  235. transform = self.transform
  236. target_transform = self.target_transform
  237. Y = [0] * len(X)
  238. val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
  239. sampler = None
  240. data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=params.batchSize, \
  241. shuffle=False, sampler=sampler, num_workers=int(params.workers), \
  242. collate_fn=collate_fn)
  243. recorder = self.recorder
  244. recorder.print('Start Predict ', print_prefix)
  245. return torch.softmax(self._predict(data_loader), axis=1)
  246. def _val(self, data_loader, print_prefix):
  247. model = self.model
  248. criterion = self.criterion
  249. recorder = self.recorder
  250. device = self.device
  251. recorder.print('Start val ', print_prefix)
  252. model.eval()
  253. n_correct = 0
  254. pred_num = 0
  255. loss_value = 0
  256. with torch.no_grad():
  257. for i, data in enumerate(data_loader):
  258. X = data[0].to(device)
  259. Y = data[1].to(device)
  260. pred_Y = model(X)
  261. correct_num = sum(Y == pred_Y.argmax(axis=1))
  262. loss = criterion(pred_Y, Y)
  263. loss_value += loss.item()
  264. n_correct += correct_num
  265. pred_num += len(X)
  266. accuracy = float(n_correct) / float(pred_num)
  267. recorder.print('[%s] Val loss: %f, accuray: %f' % (print_prefix, loss_value, accuracy))
  268. return accuracy
  269. def val(self, data_loader = None, X = None, y = None, print_prefix = ""):
  270. if data_loader is None:
  271. params = self.params
  272. collate_fn = self.collate_fn
  273. transform = self.transform
  274. target_transform = self.target_transform
  275. Y = self.str2ints(y)
  276. val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
  277. sampler = None
  278. data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=params.batchSize, \
  279. shuffle=True, sampler=sampler, num_workers=int(params.workers), \
  280. collate_fn=collate_fn)
  281. return self._val(data_loader, print_prefix)
  282. def score(self, data_loader = None, X = None, y = None, print_prefix = ""):
  283. return self.val(data_loader, X, y, print_prefix)
  284. def save(self, save_dir):
  285. recorder = self.recorder
  286. if not os.path.exists(save_dir):
  287. os.mkdir(save_dir)
  288. recorder.print("Saving model and opter")
  289. save_path = os.path.join(save_dir, "net.pth")
  290. torch.save(self.model.state_dict(), save_path)
  291. save_path = os.path.join(save_dir, "opt.pth")
  292. torch.save(self.optimizer.state_dict(), save_path)
  293. def load(self, load_dir):
  294. recorder = self.recorder
  295. recorder.print("Loading model and opter")
  296. load_path = os.path.join(load_dir, "net.pth")
  297. self.model.load_state_dict(torch.load(load_path))
  298. load_path = os.path.join(load_dir, "opt.pth")
  299. self.optimizer.load_state_dict(torch.load(load_path))
if __name__ == "__main__":
    # Import-only module: no standalone entry point is provided.
    pass

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.