
modify wabl_models.py and basic_model.py

Gao Enhao, 2 years ago · branch pull/3/head · commit 7f29b79aee
3 changed files with 94 additions and 109 deletions:
  1. models/basic_model.py (+75, -64)
  2. models/lenet5.py (+1, -1)
  3. models/wabl_models.py (+18, -44)

models/basic_model.py (+75, -64)

@@ -18,8 +18,6 @@ from torch.autograd import Variable
 from torch.utils.data import Dataset
 import torchvision
 
-import utils.utils as mutils
-
 import os
 from multiprocessing import Pool

@@ -121,6 +119,7 @@ class FakeRecorder():
 
 from torch.nn import init
 from torch import nn
+
 def weigth_init(m):
     if isinstance(m, nn.Conv2d):
         init.xavier_uniform_(m.weight.data)
@@ -137,18 +136,23 @@ class BasicModel():
                  model,
                  criterion,
                  optimizer,
                  converter,
                  device,
                  params,
                  sign_list,
+                 transform = None,
+                 target_transform=None,
+                 collate_fn = None,
+                 pretrained = False,
                  recorder = None):
 
         self.model = model.to(device)
-        self.model.apply(weigth_init)
         self.criterion = criterion
         self.optimizer = optimizer
         self.converter = converter
+        self.transform = transform
+        self.target_transform = target_transform
         self.device = device
 
         sign_list = sorted(list(set(sign_list)))
         self.mapping = dict(zip(sign_list, list(range(len(sign_list)))))
         self.remapping = dict(zip(list(range(len(sign_list))), sign_list))
@@ -157,8 +161,15 @@ class BasicModel():
             recorder = FakeRecorder()
         self.recorder = recorder
 
+        if pretrained:
+            # the paths of model, optimizer should be included in params
+            self.load(params.load_dir)
+        else:
+            self.model.apply(weigth_init)
+
+        self.save_interval = params.saveInterval
         self.params = params
+        self.collate_fn = collate_fn
         pass
 
     def _fit(self, data_loader, n_epoch, stop_loss):
@@ -171,7 +182,10 @@ class BasicModel():
             recorder.print(f"{epoch}/{n_epoch} model training loss is {loss_value}")
             if loss_value < min_loss:
                 min_loss = loss_value
-            if loss_value < stop_loss:
+            if epoch > 0 and self.save_interval is not None and epoch % self.save_interval == 0:
+                assert hasattr(self.params, 'save_dir')
+                self.save(self.params.save_dir)
+            if stop_loss is not None and loss_value < stop_loss:
                 break
         recorder.print("Model fitted, minimal loss is ", min_loss)
         return loss_value
@@ -181,30 +195,32 @@ class BasicModel():
 
     def fit(self, data_loader = None,
             X = None,
-            y = None,
-            n_epoch = 100,
-            stop_loss = 0.001):
+            y = None):
         if data_loader is None:
             params = self.params
+            collate_fn = self.collate_fn
+            transform = self.transform
+            target_transform = self.target_transform
 
             Y = self.str2ints(y)
-            train_dataset = XYDataset(X, Y)
+            train_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
             sampler = None
             data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params.batchSize, \
                 shuffle=True, sampler=sampler, num_workers=int(params.workers), \
-                collate_fn=alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio))
-        return self._fit(data_loader, n_epoch, stop_loss)
+                collate_fn=collate_fn)
+        return self._fit(data_loader, params.n_epoch, params.stop_loss)
 
     def train_epoch(self, data_loader):
-        loss_avg = mutils.averager()
+        # loss_avg = mutils.averager()
         self.model.train()
 
+        loss_value = 0
         for i, data in enumerate(data_loader):
             X = data[0]
             Y = data[1]
-            cost = self.train_batch(X, Y)
-            loss_avg.add(cost)
-            loss_value = float(loss_avg.val())
-            loss_avg.reset()
+            loss = self.train_batch(X, Y)
+            loss_value += loss.item()
 
         return loss_value
 
     def train_batch(self, X, Y):
@@ -212,17 +228,10 @@ class BasicModel():
         model = self.model
         criterion = self.criterion
         optimizer = self.optimizer
-        converter = self.converter
         device = self.device
 
-        # set training mode
-        for p in model.parameters():
-            p.requires_grad = True
-        model.train()
-        # init training status
-        torch.autograd.set_detect_anomaly(True)
-        optimizer.zero_grad()
+        # torch.autograd.set_detect_anomaly(True)
 
         # model predict
         X = X.to(device)
@@ -233,41 +242,39 @@ class BasicModel():
         loss = criterion(pred_Y, Y)
 
         # back propagation and optimize
+        optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         return loss
 
     def _predict(self, data_loader):
         model = self.model
-        criterion = self.criterion
-        converter = self.converter
-        params = self.params
         device = self.device
-        for p in model.parameters():
-            p.requires_grad = False
         model.eval()
-        n_correct = 0
-        results = []
-        for i, data in enumerate(data_loader):
-            X = data[0].to(device)
-            pred_Y = model(X)
-            results.append(pred_Y)
 
+        with torch.no_grad():
+            results = []
+            for i, data in enumerate(data_loader):
+                X = data[0].to(device)
+                pred_Y = model(X)
+                results.append(pred_Y)
         return torch.cat(results, axis=0)
 
     def predict(self, data_loader = None, X = None, print_prefix = ""):
+        params = self.params
         if data_loader is None:
-            params = self.params
+            collate_fn = self.collate_fn
+            transform = self.transform
+            target_transform = self.target_transform
 
             Y = [0] * len(X)
-            val_dataset = XYDataset(X, Y)
+            val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
             sampler = None
             data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=params.batchSize, \
                 shuffle=False, sampler=sampler, num_workers=int(params.workers), \
-                collate_fn=alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio))
+                collate_fn=collate_fn)
 
         recorder = self.recorder
         recorder.print('Start Predict ', print_prefix)
@@ -275,14 +282,18 @@ class BasicModel():
         return [self.remapping[int(y)] for y in Y]
 
     def predict_proba(self, data_loader = None, X = None, print_prefix = ""):
+        params = self.params
         if data_loader is None:
-            params = self.params
+            collate_fn = self.collate_fn
+            transform = self.transform
+            target_transform = self.target_transform
 
             Y = [0] * len(X)
-            val_dataset = XYDataset(X, Y)
+            val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
             sampler = None
             data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=params.batchSize, \
                 shuffle=False, sampler=sampler, num_workers=int(params.workers), \
-                collate_fn=alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio))
+                collate_fn=collate_fn)
 
         recorder = self.recorder
         recorder.print('Start Predict ', print_prefix)
@@ -292,45 +303,45 @@ class BasicModel():
         model = self.model
         criterion = self.criterion
         recorder = self.recorder
-        converter = self.converter
         params = self.params
         device = self.device
         recorder.print('Start val ', print_prefix)
-        for p in model.parameters():
-            p.requires_grad = False
         model.eval()
         n_correct = 0
         pred_num = 0
-        loss_avg = mutils.averager()
-        for i, data in enumerate(data_loader):
-            X = data[0].to(device)
-            Y = data[1].to(device)
+        loss_value = 0
+        with torch.no_grad():
+            for i, data in enumerate(data_loader):
+                X = data[0].to(device)
+                Y = data[1].to(device)
 
-            pred_Y = model(X)
+                pred_Y = model(X)
 
-            correct_num = sum(Y == pred_Y.argmax(axis=1))
-            loss = criterion(pred_Y, Y)
-            loss_avg.add(loss)
+                correct_num = sum(Y == pred_Y.argmax(axis=1))
+                loss = criterion(pred_Y, Y)
+                loss_value += loss.item()
 
-            n_correct += correct_num
-            pred_num += len(X)
+                n_correct += correct_num
+                pred_num += len(X)
 
         accuracy = float(n_correct) / float(pred_num)
-        recorder.print('[%s] Val loss: %f, accuray: %f' % (print_prefix, loss_avg.val(), accuracy))
+        recorder.print('[%s] Val loss: %f, accuray: %f' % (print_prefix, loss_value, accuracy))
         return accuracy
 
     def val(self, data_loader = None, X = None, y = None, print_prefix = ""):
+        params = self.params
         if data_loader is None:
-            y = self.str2ints(y)
-            val_dataset = XYDataset(X, y)
-            params = self.params
+            collate_fn = self.collate_fn
+            transform = self.transform
+            target_transform = self.target_transform
 
+            Y = self.str2ints(y)
+            val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
             sampler = None
             data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=params.batchSize, \
                 shuffle=True, sampler=sampler, num_workers=int(params.workers), \
-                collate_fn=alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio))
+                collate_fn=collate_fn)
         return self._val(data_loader, print_prefix)
 
     def score(self, data_loader = None, X = None, y = None, print_prefix = ""):

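Note: below is a minimal, hypothetical usage sketch of the reworked BasicModel, not code from this commit. It assumes a params object carrying the attributes the new code reads (batchSize, workers, n_epoch, stop_loss, saveInterval) and MNIST-style 1x28x28 tensor inputs; the random data and hyperparameter values are placeholders.

from types import SimpleNamespace

import torch
from torch import nn

from models.basic_model import BasicModel
from models.lenet5 import LeNet5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNet5()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Attributes referenced by the new code paths: saveInterval=None skips the
# periodic self.save() call, and pretrained=False applies weigth_init.
params = SimpleNamespace(batchSize=32, workers=0, n_epoch=10,
                         stop_loss=0.001, saveInterval=None)

# converter is passed as None: it is still stored, but this commit removes
# its remaining uses in train_batch / _predict / _val.
cls = BasicModel(model, criterion, optimizer, None, device, params,
                 sign_list=list(range(10)), pretrained=False)

X = [torch.randn(1, 28, 28) for _ in range(64)]  # stand-in images
y = [i % 10 for i in range(64)]                  # stand-in labels
cls.fit(X=X, y=y)          # n_epoch and stop_loss now come from params
labels = cls.predict(X=X)  # indices mapped back through self.remapping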

models/lenet5.py (+1, -1)

@@ -34,7 +34,7 @@ class LeNet5(nn.Module):
 
         self.fc1 = nn.Linear(256, 120)
         self.fc2 = nn.Linear(120, 84)
-        self.fc3 = nn.Linear(84, 13)
+        self.fc3 = nn.Linear(84, 10)
 
     def forward(self, x):
         '''forward propagation function'''

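Note: the only change above shrinks the classifier head from 13 to 10 classes. A quick sanity check of the new output shape, assuming the 1x28x28 single-channel input implied by the 256-unit fc1 layer:

import torch
from models.lenet5 import LeNet5

net = LeNet5()
out = net(torch.randn(4, 1, 28, 28))  # batch of 4 MNIST-style images
print(out.shape)                      # expected: torch.Size([4, 10])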

models/wabl_models.py (+18, -44)

@@ -21,6 +21,7 @@ from sklearn.preprocessing import StandardScaler
 from sklearn.svm import SVC
 from sklearn.gaussian_process import GaussianProcessClassifier
 from sklearn.gaussian_process.kernels import RBF
+from models.basic_model import BasicModel
 
 import pickle as pk
 import random
@@ -36,7 +37,6 @@ def merge_data(X):
     ret_X = list(chain(*X))
     return ret_X, ret_mark
 
-
 def reshape_data(Y, marks):
     begin_mark = 0
     ret_Y = []
@@ -58,57 +58,26 @@ class WABLBasicModel:
         pass
 
     def predict(self, X):
-        if self.share:
-            data_X, marks = merge_data(X)
-            prob = self.cls_list[0].predict_proba(X = data_X)
-            cls = np.array(prob).argmax(axis = 1)
+        data_X, marks = merge_data(X)
+        prob = self.cls_list[0].predict_proba(X = data_X)
+        cls = np.array(prob).argmax(axis = 1)
 
-            prob = reshape_data(prob, marks)
-            cls = reshape_data(cls, marks)
-        else:
-            cls_result = []
-            prob_result = []
-            for i in range(self.code_len):
-                data_X = get_part_data(X, i)
-                tmp_prob = self.cls_list[i].predict_proba(X = data_X)
-                cls_result.append(np.array(tmp_prob).argmax(axis = 1))
-                prob_result.append(tmp_prob)
-
-            cls = list(zip(*cls_result))
-            prob = list(zip(*prob_result))
+        prob = reshape_data(prob, marks)
+        cls = reshape_data(cls, marks)
 
         return {"cls" : cls, "prob" : prob}
 
     def valid(self, X, Y):
-        if self.share:
-            data_X, _ = merge_data(X)
-            data_Y, _ = merge_data(Y)
-            score = self.cls_list[0].score(X = data_X, y = data_Y)
-            return score, [score]
-        else:
-            score_list = []
-            for i in range(self.code_len):
-                data_X = get_part_data(X, i)
-                data_Y = get_part_data(Y, i)
-                score_list.append(self.cls_list[i].score(data_X, data_Y))
-
-            return sum(score_list) / len(score_list), score_list
+        data_X, _ = merge_data(X)
+        data_Y, _ = merge_data(Y)
+        score = self.cls_list[0].score(X = data_X, y = data_Y)
+        return score, [score]
 
     def train(self, X, Y):
         #self.label_lists = []
-        if self.share:
-            data_X, _ = merge_data(X)
-            data_Y, _ = merge_data(Y)
-            self.cls_list[0].fit(X = data_X, y = data_Y)
-        else:
-            for i in range(self.code_len):
-                data_X = get_part_data(X, i)
-                data_Y = get_part_data(Y, i)
-                self.cls_list[i].fit(data_X, data_Y)
-
-    def _set_label_lists(self, label_lists):
-        label_lists = [sorted(list(set(label_list))) for label_list in label_lists]
-        self.label_lists = label_lists
+        data_X, _ = merge_data(X)
+        data_Y, _ = merge_data(Y)
+        self.cls_list[0].fit(X = data_X, y = data_Y)
 
 class DecisionTree(WABLBasicModel):
     def __init__(self, code_len, label_lists, share = False):
@@ -169,6 +138,11 @@ class CNN(WABLBasicModel):
                 self.cls_list[i].fit(data_X, data_Y)
                 #self.label_lists.append(sorted(list(set(data_Y))))
 
+class MyModel(WABLBasicModel):
+    def __init__(self, base_model):
+
+        self.cls_list = []
+        self.cls_list.append(base_model)
+
 if __name__ == "__main__":
     #data_path = "utils/hamming_data/generated_data/hamming_7_3_0.20.pk"

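Note: with the share branches gone, WABLBasicModel always routes through the single shared classifier cls_list[0] on merge_data-flattened inputs, and the new MyModel is a thin wrapper that installs any pre-built classifier there. Below is a hypothetical composition sketch, not code from this commit; base_model stands for a BasicModel built as in the basic_model.py note above, and X/Y are nested lists (one list of per-symbol inputs or labels per example), the shape merge_data and reshape_data expect.

from models.wabl_models import MyModel

wabl = MyModel(base_model)   # base_model: a fitted BasicModel (placeholder)

# X: e.g. [[img1, img2, img3], [img4, img5], ...]; Y mirrors that nesting.
result = wabl.predict(X)     # {"cls": [...], "prob": [...]}, regrouped per example
wabl.train(X, Y)             # flattens X/Y, then calls cls_list[0].fit(...)
score, score_list = wabl.valid(X, Y)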
