# coding: utf-8
#================================================================#
#   Copyright (C) 2020 Freecss All rights reserved.
#
#   File Name   : basic_model.py
#   Author      : freecss
#   Email       : karlfreecss@gmail.com
#   Created Date: 2020/11/21
#   Description : A sklearn-style wrapper around a PyTorch classifier:
#                 fit / predict / predict_proba / val / save / load.
#
#================================================================#

import os
import sys

sys.path.append("..")  # allow imports from the repository root

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch import nn
from torch.nn import init
from torch.utils.data import DataLoader, Dataset

class resizeNormalize(object):
    """Convert an image to a tensor and normalize it from [0, 1] to [-1, 1]."""

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation
        self.transform = transforms.Compose([
            # Optional augmentations (disabled): RandomHorizontalFlip,
            # RandomVerticalFlip, RandomRotation, RandomAffine.
            transforms.ToTensor(),
        ])

    def __call__(self, img):
        # Resizing is intentionally disabled; inputs are assumed to share a size.
        # img = img.resize(self.size, self.interpolation)
        img = self.transform(img)
        img.sub_(0.5).div_(0.5)  # map [0, 1] to [-1, 1]
        return img

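# Illustrative note: resizeNormalize((W, H)) turns a PIL image into a
# C x H x W float tensor in [-1, 1], since (x - 0.5) / 0.5 maps the
# ToTensor range [0, 1] onto [-1, 1].
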
class XYDataset(Dataset):
    """An in-memory dataset over parallel lists of inputs X and labels Y."""

    def __init__(self, X, Y, transform=None, target_transform=None):
        self.X = X
        self.Y = Y

        self.n_sample = len(X)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        assert index < len(self), 'index range error'

        img = self.X[index]
        if self.transform is not None:
            img = self.transform(img)

        label = self.Y[index]
        if self.target_transform is not None:
            label = self.target_transform(label)

        return (img, label, index)

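# Example (illustrative):
#   ds = XYDataset(images, labels, transform=resizeNormalize((32, 32)))
#   img, label, idx = ds[0]  # transformed input, its label, and its index
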
class alignCollate(object):
    """Batch collate_fn that normalizes images to a common size before stacking."""

    def __init__(self, imgH=32, imgW=100, keep_ratio=False, min_ratio=1):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio = keep_ratio
        self.min_ratio = min_ratio

    def __call__(self, batch):
        images, labels, img_keys = zip(*batch)

        imgH = self.imgH
        imgW = self.imgW
        if self.keep_ratio:
            ratios = []
            for image in images:
                h, w = image.shape[:2]  # numpy arrays are (height, width, ...)
                ratios.append(w / float(h))
            ratios.sort()
            max_ratio = ratios[-1]
            imgW = int(np.floor(max_ratio * imgH))
            imgW = max(imgH * self.min_ratio, imgW)  # ensure imgW >= imgH * min_ratio

        transform = resizeNormalize((imgW, imgH))
        images = [transform(image) for image in images]
        images = torch.cat([t.unsqueeze(0) for t in images], 0)
        labels = torch.LongTensor(labels)

        return images, labels, img_keys

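# Example (illustrative): use an instance as a DataLoader collate_fn,
#   DataLoader(ds, batch_size=64, collate_fn=alignCollate(imgH=32, imgW=100))
# which yields (images, labels, img_keys) batches with images stacked into
# a single (N, C, H, W) tensor.
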
class FakeRecorder:
    """A no-op recorder used when no logger is supplied."""

    def print(self, *x):
        pass

def weight_init(m):
    """Initialize weights: Xavier for conv layers, N(0, 0.01) for linear layers."""
    if isinstance(m, nn.Conv2d):
        init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            init.constant_(m.bias.data, 0.1)
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 0.01)
        if m.bias is not None:
            m.bias.data.zero_()

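# Example (illustrative): net.apply(weight_init) initializes every
# submodule recursively; BasicModel does this automatically when
# pretrained=False.
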
class BasicModel:
    def __init__(self,
                 model,
                 criterion,
                 optimizer,
                 device,
                 params,
                 sign_list,
                 transform=None,
                 target_transform=None,
                 collate_fn=None,
                 pretrained=False,
                 recorder=None):

        self.model = model.to(device)

        self.criterion = criterion
        self.optimizer = optimizer
        self.transform = transform
        self.target_transform = target_transform
        self.device = device

        # map each distinct sign to a class index, and back
        sign_list = sorted(set(sign_list))
        self.mapping = dict(zip(sign_list, range(len(sign_list))))
        self.remapping = dict(zip(range(len(sign_list)), sign_list))

        if recorder is None:
            recorder = FakeRecorder()
        self.recorder = recorder

        if pretrained:
            # params must provide load_dir with the model/optimizer checkpoints
            self.load(params.load_dir)
        else:
            self.model.apply(weight_init)

        self.save_interval = params.saveInterval
        self.params = params
        self.collate_fn = collate_fn

    def _fit(self, data_loader, n_epoch, stop_loss):
        recorder = self.recorder
        recorder.print("model fitting")

        min_loss = float("inf")
        for epoch in range(n_epoch):
            loss_value = self.train_epoch(data_loader)
            recorder.print(f"{epoch}/{n_epoch} model training loss is {loss_value}")
            if loss_value < min_loss:
                min_loss = loss_value
            if epoch > 0 and self.save_interval is not None and epoch % self.save_interval == 0:
                assert hasattr(self.params, 'save_dir')
                self.save(self.params.save_dir)
            if stop_loss is not None and loss_value < stop_loss:
                break
        recorder.print(f"Model fitted, minimal loss is {min_loss}")
        return loss_value

    def str2ints(self, Y):
        return [self.mapping[y] for y in Y]

    def fit(self, data_loader=None, X=None, y=None):
        params = self.params
        if data_loader is None:
            Y = self.str2ints(y)
            train_dataset = XYDataset(X, Y, transform=self.transform,
                                      target_transform=self.target_transform)
            data_loader = DataLoader(train_dataset, batch_size=params.batchSize,
                                     shuffle=True, num_workers=int(params.workers),
                                     collate_fn=self.collate_fn)
        return self._fit(data_loader, params.n_epoch, params.stop_loss)

    def train_epoch(self, data_loader):
        self.model.train()

        loss_value = 0
        for data in data_loader:
            X, Y = data[0], data[1]
            loss = self.train_batch(X, Y)
            loss_value += loss.item()

        return loss_value

    def train_batch(self, X, Y):
        model = self.model
        criterion = self.criterion
        optimizer = self.optimizer
        device = self.device

        # forward pass
        X = X.to(device)
        Y = Y.to(device)
        pred_Y = model(X)

        # calculate loss
        loss = criterion(pred_Y, Y)

        # back propagation and optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss

    def _predict(self, data_loader):
        model = self.model
        device = self.device

        model.eval()

        with torch.no_grad():
            results = []
            for data in data_loader:
                X = data[0].to(device)
                pred_Y = model(X)
                results.append(pred_Y)

        return torch.cat(results, dim=0)

    def predict(self, data_loader=None, X=None, print_prefix=""):
        if data_loader is None:
            params = self.params

            # dummy labels; only the inputs are used for prediction
            Y = [0] * len(X)
            val_dataset = XYDataset(X, Y, transform=self.transform,
                                    target_transform=self.target_transform)
            data_loader = DataLoader(val_dataset, batch_size=params.batchSize,
                                     shuffle=False, num_workers=int(params.workers),
                                     collate_fn=self.collate_fn)

        recorder = self.recorder
        recorder.print('Start Predict ', print_prefix)
        Y = self._predict(data_loader).argmax(dim=1)
        return [self.remapping[int(y)] for y in Y]

    def predict_proba(self, data_loader=None, X=None, print_prefix=""):
        if data_loader is None:
            params = self.params

            Y = [0] * len(X)  # dummy labels
            val_dataset = XYDataset(X, Y, transform=self.transform,
                                    target_transform=self.target_transform)
            data_loader = DataLoader(val_dataset, batch_size=params.batchSize,
                                     shuffle=False, num_workers=int(params.workers),
                                     collate_fn=self.collate_fn)

        recorder = self.recorder
        recorder.print('Start Predict ', print_prefix)
        return torch.softmax(self._predict(data_loader), dim=1)

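    # Note (illustrative): predict maps class indices back to the original
    # sign labels via self.remapping, while predict_proba returns an
    # (N, n_classes) tensor whose rows sum to 1.
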
    def _val(self, data_loader, print_prefix):
        model = self.model
        criterion = self.criterion
        recorder = self.recorder
        device = self.device
        recorder.print('Start val ', print_prefix)

        model.eval()

        n_correct = 0
        pred_num = 0
        loss_value = 0
        with torch.no_grad():
            for data in data_loader:
                X = data[0].to(device)
                Y = data[1].to(device)

                pred_Y = model(X)

                correct_num = (Y == pred_Y.argmax(dim=1)).sum().item()
                loss = criterion(pred_Y, Y)
                loss_value += loss.item()

                n_correct += correct_num
                pred_num += len(X)

        accuracy = float(n_correct) / float(pred_num)
        recorder.print('[%s] Val loss: %f, accuracy: %f' % (print_prefix, loss_value, accuracy))
        return accuracy

    def val(self, data_loader=None, X=None, y=None, print_prefix=""):
        if data_loader is None:
            params = self.params

            Y = self.str2ints(y)
            val_dataset = XYDataset(X, Y, transform=self.transform,
                                    target_transform=self.target_transform)
            # order does not affect the metrics, so no shuffling is needed
            data_loader = DataLoader(val_dataset, batch_size=params.batchSize,
                                     shuffle=False, num_workers=int(params.workers),
                                     collate_fn=self.collate_fn)
        return self._val(data_loader, print_prefix)

    def score(self, data_loader=None, X=None, y=None, print_prefix=""):
        return self.val(data_loader, X, y, print_prefix)

    def save(self, save_dir):
        recorder = self.recorder
        os.makedirs(save_dir, exist_ok=True)
        recorder.print("Saving model and optimizer")
        save_path = os.path.join(save_dir, "net.pth")
        torch.save(self.model.state_dict(), save_path)

        save_path = os.path.join(save_dir, "opt.pth")
        torch.save(self.optimizer.state_dict(), save_path)

    def load(self, load_dir):
        recorder = self.recorder
        recorder.print("Loading model and optimizer")
        load_path = os.path.join(load_dir, "net.pth")
        self.model.load_state_dict(torch.load(load_path, map_location=self.device))

        load_path = os.path.join(load_dir, "opt.pth")
        self.optimizer.load_state_dict(torch.load(load_path, map_location=self.device))

if __name__ == "__main__":
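    # Minimal smoke test (illustrative): the toy network and the params
    # fields below are assumptions for demonstration, not part of the project.
    from types import SimpleNamespace

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = nn.Sequential(
        nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2),
    )
    params = SimpleNamespace(batchSize=4, workers=0, n_epoch=2,
                             stop_loss=None, saveInterval=None)
    model = BasicModel(net, nn.CrossEntropyLoss(),
                       torch.optim.Adam(net.parameters()),
                       device, params, sign_list=["neg", "pos"],
                       transform=resizeNormalize((28, 28)))

    # eight random grayscale images with alternating labels
    X = [Image.fromarray((np.random.rand(28, 28) * 255).astype(np.uint8))
         for _ in range(8)]
    y = ["neg", "pos"] * 4
    model.fit(X=X, y=y)
    print(model.predict(X=X))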