From 7f29b79aeeb7982410fbc9aafc7e3b090524c7ce Mon Sep 17 00:00:00 2001
From: Gao Enhao
Date: Tue, 15 Nov 2022 23:25:17 +0800
Subject: [PATCH] modify wabl_models.py and basic_model.py
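
BasicModel is decoupled from the CRNN-specific converter and
alignCollate helpers: dataset transforms and the batch collate function
are now injected through the constructor (transform, target_transform,
collate_fn), the epoch count and early-stop threshold are read from
params (params.n_epoch, params.stop_loss), checkpoints are written every
saveInterval epochs, a pretrained flag loads weights from
params.load_dir instead of re-initializing, and the prediction and
validation paths run under torch.no_grad(). WABLBasicModel keeps only
the shared single-classifier path, and a thin MyModel wrapper is added.

A rough construction sketch of the new interface follows (the Params
fields and the LeNet5 pairing are illustrative, not part of this patch):

    import torch
    from torch import nn
    from models.lenet5 import LeNet5
    from models.basic_model import BasicModel

    class Params:
        # hypothetical config object; only the attributes BasicModel reads
        batchSize = 64
        workers = 2
        n_epoch = 10
        stop_loss = 0.001
        saveInterval = 5
        save_dir = "./ckpt"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = LeNet5()
    base = BasicModel(net, nn.CrossEntropyLoss(),
                      torch.optim.Adam(net.parameters()),
                      device, Params(), sign_list=list(range(10)))
    # base.fit(X=images, y=labels) now takes its epoch/stop settings
    # from Params rather than from keyword arguments.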
---
 models/basic_model.py | 139 +++++++++++++++++++++++-------------
 models/lenet5.py      |   2 +-
 models/wabl_models.py |  62 ++++++-------------
 3 files changed, 94 insertions(+), 109 deletions(-)

diff --git a/models/basic_model.py b/models/basic_model.py
index c54bb22..47aeab9 100644
--- a/models/basic_model.py
+++ b/models/basic_model.py
@@ -18,8 +18,6 @@ from torch.autograd import Variable
 from torch.utils.data import Dataset
 import torchvision
 
-import utils.utils as mutils
-
 import os
 from multiprocessing import Pool
 
@@ -121,6 +119,7 @@ class FakeRecorder():
 from torch.nn import init
 from torch import nn
 
+
 def weigth_init(m):
     if isinstance(m, nn.Conv2d):
         init.xavier_uniform_(m.weight.data)
@@ -137,18 +136,23 @@ class BasicModel():
                  model,
                  criterion,
                  optimizer,
-                 converter,
                  device,
                  params,
                  sign_list,
+                 transform = None,
+                 target_transform=None,
+                 collate_fn = None,
+                 pretrained = False,
                  recorder = None):
 
         self.model = model.to(device)
-        self.model.apply(weigth_init)
+
         self.criterion = criterion
         self.optimizer = optimizer
-        self.converter = converter
+        self.transform = transform
+        self.target_transform = target_transform
         self.device = device
 
+        sign_list = sorted(list(set(sign_list)))
         self.mapping = dict(zip(sign_list, list(range(len(sign_list)))))
         self.remapping = dict(zip(list(range(len(sign_list))), sign_list))
 
@@ -157,8 +161,15 @@ class BasicModel():
             recorder = FakeRecorder()
         self.recorder = recorder
 
+        if pretrained:
+            # the paths of model, optimizer should be included in params
+            self.load(params.load_dir)
+        else:
+            self.model.apply(weigth_init)
+
         self.save_interval = params.saveInterval
         self.params = params
+        self.collate_fn = collate_fn
         pass
 
    def _fit(self, data_loader, n_epoch, stop_loss):
@@ -171,7 +182,10 @@ class BasicModel():
             recorder.print(f"{epoch}/{n_epoch} model training loss is {loss_value}")
             if loss_value < min_loss:
                 min_loss = loss_value
-            if loss_value < stop_loss:
+            if epoch > 0 and self.save_interval is not None and epoch % self.save_interval == 0:
+                assert hasattr(self.params, 'save_dir')
+                self.save(self.params.save_dir)
+            if stop_loss is not None and loss_value < stop_loss:
                 break
         recorder.print("Model fitted, minimal loss is ", min_loss)
         return loss_value
@@ -181,30 +195,32 @@ class BasicModel():
     def fit(self,
             data_loader = None,
             X = None,
-            y = None,
-            n_epoch = 100,
-            stop_loss = 0.001):
+            y = None):
         if data_loader is None:
             params = self.params
+            collate_fn = self.collate_fn
+            transform = self.transform
+            target_transform = self.target_transform
+
             Y = self.str2ints(y)
-            train_dataset = XYDataset(X, Y)
+            train_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
             sampler = None
             data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params.batchSize, \
                     shuffle=True, sampler=sampler, num_workers=int(params.workers), \
-                    collate_fn=alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio))
-        return self._fit(data_loader, n_epoch, stop_loss)
+                    collate_fn=collate_fn)
+        return self._fit(data_loader, self.params.n_epoch, self.params.stop_loss)
 
     def train_epoch(self, data_loader):
-        loss_avg = mutils.averager()
+        # loss_avg = mutils.averager()
+        self.model.train()
+        loss_value = 0
         for i, data in enumerate(data_loader):
             X = data[0]
             Y = data[1]
-            cost = self.train_batch(X, Y)
-            loss_avg.add(cost)
-
-        loss_value = float(loss_avg.val())
-        loss_avg.reset()
+            loss = self.train_batch(X, Y)
+            loss_value += loss.item()
+
         return loss_value
 
     def train_batch(self, X, Y):
@@ -212,17 +228,10 @@ class BasicModel():
         model = self.model
         criterion = self.criterion
         optimizer = self.optimizer
-        converter = self.converter
         device = self.device
-
-        # set training mode
-        for p in model.parameters():
-            p.requires_grad = True
-        model.train()
 
         # init training status
-        torch.autograd.set_detect_anomaly(True)
-        optimizer.zero_grad()
+        # torch.autograd.set_detect_anomaly(True)
 
         # model predict
         X = X.to(device)
@@ -233,41 +242,39 @@ class BasicModel():
         loss = criterion(pred_Y, Y)
 
         # back propagation and optimize
+        optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         return loss
 
     def _predict(self, data_loader):
         model = self.model
-        criterion = self.criterion
-        converter = self.converter
-        params = self.params
         device = self.device
-
-        for p in model.parameters():
-            p.requires_grad = False
         model.eval()
-
-        n_correct = 0
-
-        results = []
-        for i, data in enumerate(data_loader):
-            X = data[0].to(device)
-            pred_Y = model(X)
-            results.append(pred_Y)
+
+        with torch.no_grad():
+            results = []
+            for i, data in enumerate(data_loader):
+                X = data[0].to(device)
+                pred_Y = model(X)
+                results.append(pred_Y)
 
         return torch.cat(results, axis=0)
 
     def predict(self, data_loader = None, X = None, print_prefix = ""):
-        params = self.params
         if data_loader is None:
+            params = self.params
+            collate_fn = self.collate_fn
+            transform = self.transform
+            target_transform = self.target_transform
+
             Y = [0] * len(X)
-            val_dataset = XYDataset(X, Y)
+            val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
             sampler = None
             data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=params.batchSize, \
                     shuffle=False, sampler=sampler, num_workers=int(params.workers), \
-                    collate_fn=alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio))
+                    collate_fn=collate_fn)
 
         recorder = self.recorder
         recorder.print('Start Predict ', print_prefix)
@@ -275,14 +282,18 @@ class BasicModel():
         return [self.remapping[int(y)] for y in Y]
 
     def predict_proba(self, data_loader = None, X = None, print_prefix = ""):
-        params = self.params
         if data_loader is None:
+            params = self.params
+            collate_fn = self.collate_fn
+            transform = self.transform
+            target_transform = self.target_transform
+
             Y = [0] * len(X)
-            val_dataset = XYDataset(X, Y)
+            val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
             sampler = None
             data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=params.batchSize, \
                     shuffle=False, sampler=sampler, num_workers=int(params.workers), \
-                    collate_fn=alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio))
+                    collate_fn=collate_fn)
 
         recorder = self.recorder
         recorder.print('Start Predict ', print_prefix)
@@ -292,45 +303,45 @@ class BasicModel():
         model = self.model
         criterion = self.criterion
         recorder = self.recorder
-        converter = self.converter
-        params = self.params
         device = self.device
 
         recorder.print('Start val ', print_prefix)
-
-        for p in model.parameters():
-            p.requires_grad = False
         model.eval()
 
         n_correct = 0
         pred_num = 0
-        loss_avg = mutils.averager()
-        for i, data in enumerate(data_loader):
-            X = data[0].to(device)
-            Y = data[1].to(device)
+        loss_value = 0
+        with torch.no_grad():
+            for i, data in enumerate(data_loader):
+                X = data[0].to(device)
+                Y = data[1].to(device)
 
-            pred_Y = model(X)
+                pred_Y = model(X)
 
-            correct_num = sum(Y == pred_Y.argmax(axis=1))
-            loss = criterion(pred_Y, Y)
-            loss_avg.add(loss)
+                correct_num = sum(Y == pred_Y.argmax(axis=1))
+                loss = criterion(pred_Y, Y)
+                loss_value += loss.item()
 
-            n_correct += correct_num
-            pred_num += len(X)
+                n_correct += correct_num
+                pred_num += len(X)
 
         accuracy = float(n_correct) / float(pred_num)
-        recorder.print('[%s] Val loss: %f, accuray: %f' % (print_prefix, loss_avg.val(), accuracy))
+        recorder.print('[%s] Val loss: %f, accuracy: %f' % (print_prefix, loss_value, accuracy))
         return accuracy
 
     def val(self, data_loader = None, X = None, y = None, print_prefix = ""):
-        params = self.params
         if data_loader is None:
-            y = self.str2ints(y)
-            val_dataset = XYDataset(X, y)
+            params = self.params
+            collate_fn = self.collate_fn
+            transform = self.transform
+            target_transform = self.target_transform
+
+            Y = self.str2ints(y)
+            val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
             sampler = None
             data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=params.batchSize, \
                     shuffle=True, sampler=sampler, num_workers=int(params.workers), \
-                    collate_fn=alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio))
+                    collate_fn=collate_fn)
         return self._val(data_loader, print_prefix)
 
     def score(self, data_loader = None, X = None, y = None, print_prefix = ""):
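
[Annotation, not part of the patch] The _predict/_val rewrite above
replaces manual requires_grad toggling with torch.no_grad(), which
suspends autograd tracking only inside the block and leaves the
parameters trainable afterwards. A minimal self-contained sketch of the
behaviour the new code relies on:

    import torch
    from torch import nn

    model = nn.Linear(4, 2)
    model.eval()                  # inference behaviour for dropout/BN layers
    with torch.no_grad():         # no autograd graph is recorded here
        y = model(torch.randn(8, 4))
    assert not y.requires_grad    # outputs carry no gradient history
    assert all(p.requires_grad for p in model.parameters())  # training can resume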
diff --git a/models/lenet5.py b/models/lenet5.py
index 676f553..9d5b054 100644
--- a/models/lenet5.py
+++ b/models/lenet5.py
@@ -34,7 +34,7 @@ class LeNet5(nn.Module):
 
         self.fc1 = nn.Linear(256, 120)
         self.fc2 = nn.Linear(120, 84)
-        self.fc3 = nn.Linear(84, 13)
+        self.fc3 = nn.Linear(84, 10)
 
     def forward(self, x):
         '''Forward pass'''
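
[Annotation, not part of the patch] The final layer now emits 10 logits,
presumably to match the 10 digit classes handled elsewhere; with the
usual 1x28x28 input of this LeNet-5 variant, the flattened feature size
is 256, consistent with fc1 above. A quick sanity check, assuming that
input shape:

    import torch
    from models.lenet5 import LeNet5

    net = LeNet5()
    logits = net(torch.randn(2, 1, 28, 28))  # batch of two 1x28x28 images
    print(logits.shape)                      # expected: torch.Size([2, 10])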
diff --git a/models/wabl_models.py b/models/wabl_models.py
index 4b27418..15fe6e6 100644
--- a/models/wabl_models.py
+++ b/models/wabl_models.py
@@ -21,6 +21,7 @@ from sklearn.preprocessing import StandardScaler
 from sklearn.svm import SVC
 from sklearn.gaussian_process import GaussianProcessClassifier
 from sklearn.gaussian_process.kernels import RBF
+from models.basic_model import BasicModel
 
 import pickle as pk
 import random
@@ -36,7 +37,6 @@ def merge_data(X):
     ret_X = list(chain(*X))
     return ret_X, ret_mark
 
-
 def reshape_data(Y, marks):
     begin_mark = 0
     ret_Y = []
@@ -58,57 +58,26 @@ class WABLBasicModel:
         pass
 
     def predict(self, X):
-        if self.share:
-            data_X, marks = merge_data(X)
-            prob = self.cls_list[0].predict_proba(X = data_X)
-            cls = np.array(prob).argmax(axis = 1)
+        data_X, marks = merge_data(X)
+        prob = self.cls_list[0].predict_proba(X = data_X)
+        cls = np.array(prob).argmax(axis = 1)
 
-            prob = reshape_data(prob, marks)
-            cls = reshape_data(cls, marks)
-        else:
-            cls_result = []
-            prob_result = []
-            for i in range(self.code_len):
-                data_X = get_part_data(X, i)
-                tmp_prob = self.cls_list[i].predict_proba(X = data_X)
-                cls_result.append(np.array(tmp_prob).argmax(axis = 1))
-                prob_result.append(tmp_prob)
-
-            cls = list(zip(*cls_result))
-            prob = list(zip(*prob_result))
+        prob = reshape_data(prob, marks)
+        cls = reshape_data(cls, marks)
 
         return {"cls" : cls, "prob" : prob}
 
     def valid(self, X, Y):
-        if self.share:
-            data_X, _ = merge_data(X)
-            data_Y, _ = merge_data(Y)
-            score = self.cls_list[0].score(X = data_X, y = data_Y)
-            return score, [score]
-        else:
-            score_list = []
-            for i in range(self.code_len):
-                data_X = get_part_data(X, i)
-                data_Y = get_part_data(Y, i)
-                score_list.append(self.cls_list[i].score(data_X, data_Y))
-
-            return sum(score_list) / len(score_list), score_list
+        data_X, _ = merge_data(X)
+        data_Y, _ = merge_data(Y)
+        score = self.cls_list[0].score(X = data_X, y = data_Y)
+        return score, [score]
 
     def train(self, X, Y):
         #self.label_lists = []
-        if self.share:
-            data_X, _ = merge_data(X)
-            data_Y, _ = merge_data(Y)
-            self.cls_list[0].fit(X = data_X, y = data_Y)
-        else:
-            for i in range(self.code_len):
-                data_X = get_part_data(X, i)
-                data_Y = get_part_data(Y, i)
-                self.cls_list[i].fit(data_X, data_Y)
-
-    def _set_label_lists(self, label_lists):
-        label_lists = [sorted(list(set(label_list))) for label_list in label_lists]
-        self.label_lists = label_lists
+        data_X, _ = merge_data(X)
+        data_Y, _ = merge_data(Y)
+        self.cls_list[0].fit(X = data_X, y = data_Y)
 
 class DecisionTree(WABLBasicModel):
     def __init__(self, code_len, label_lists, share = False):
@@ -169,6 +138,11 @@ class CNN(WABLBasicModel):
                 self.cls_list[i].fit(data_X, data_Y)
                 #self.label_lists.append(sorted(list(set(data_Y))))
 
+class MyModel(WABLBasicModel):
+    def __init__(self, base_model):
+
+        self.cls_list = []
+        self.cls_list.append(base_model)
 
 if __name__ == "__main__":
     #data_path = "utils/hamming_data/generated_data/hamming_7_3_0.20.pk"
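
[Annotation, not part of the patch] With the share branches removed,
WABLBasicModel always drives a single shared classifier, cls_list[0],
and MyModel wraps any object exposing fit / predict_proba / score with
keyword arguments X and y; the newly imported BasicModel satisfies that
interface, which is presumably why the import was added. An illustrative
pairing with a scikit-learn estimator (toy shapes, assuming merge_data
flattens one nesting level as its use of chain(*X) suggests):

    import numpy as np
    from sklearn.neighbors import KNeighborsClassifier
    from models.wabl_models import MyModel

    model = MyModel(KNeighborsClassifier(n_neighbors=3))

    # Four examples, each a sequence of two 5-dimensional symbols.
    X = [[np.random.rand(5), np.random.rand(5)] for _ in range(4)]
    Y = [[0, 1], [1, 0], [0, 1], [1, 0]]
    model.train(X, Y)
    print(model.predict(X)["cls"])   # per-example label sequences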