@@ -18,8 +18,6 @@ from torch.autograd import Variable
 from torch.utils.data import Dataset
 import torchvision

 import utils.utils as mutils

 import os
 from multiprocessing import Pool
@@ -121,6 +119,7 @@ class FakeRecorder():

 from torch.nn import init
 from torch import nn

 def weigth_init(m):
     if isinstance(m, nn.Conv2d):
         init.xavier_uniform_(m.weight.data)
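
Note: nn.Module.apply(fn) walks every submodule recursively, so weigth_init
only has to branch on layer type. A minimal sketch of how this initializer is
exercised (the two-layer net below is illustrative, not part of the patch):

    from torch import nn
    from torch.nn import init

    def weigth_init(m):
        # Xavier-initialize convolution kernels; other layer types keep
        # their PyTorch defaults.
        if isinstance(m, nn.Conv2d):
            init.xavier_uniform_(m.weight.data)

    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
    net.apply(weigth_init)  # visits each submodule in turn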

@@ -137,18 +136,23 @@ class BasicModel():
                  model,
                  criterion,
                  optimizer,
                  converter,
                  device,
                  params,
                  sign_list,
+                 transform = None,
+                 target_transform=None,
+                 collate_fn = None,
                  pretrained = False,
                  recorder = None):

         self.model = model.to(device)
-        self.model.apply(weigth_init)

         self.criterion = criterion
         self.optimizer = optimizer
         self.converter = converter
+        self.transform = transform
+        self.target_transform = target_transform
         self.device = device

         sign_list = sorted(list(set(sign_list)))
         self.mapping = dict(zip(sign_list, list(range(len(sign_list)))))
         self.remapping = dict(zip(list(range(len(sign_list))), sign_list))
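
Note: sign_list is deduplicated and sorted before indexing, so mapping and
remapping form a stable bijection between raw labels and class ids.
Illustrative round trip:

    signs = sorted(set(['b', 'a', 'c', 'a']))        # ['a', 'b', 'c']
    mapping = dict(zip(signs, range(len(signs))))    # {'a': 0, 'b': 1, 'c': 2}
    remapping = dict(zip(range(len(signs)), signs))  # {0: 'a', 1: 'b', 2: 'c'}
    assert remapping[mapping['b']] == 'b'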

@@ -157,8 +161,15 @@ class BasicModel():
             recorder = FakeRecorder()
         self.recorder = recorder

+        if pretrained:
+            # the paths of model, optimizer should be included in params
+            self.load(params.load_dir)
+        else:
+            self.model.apply(weigth_init)

         self.save_interval = params.saveInterval
         self.params = params
+        self.collate_fn = collate_fn
         pass
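
Note: BasicModel.load and BasicModel.save are not shown in this diff. The
pretrained branch above implies a contract roughly like the following
hypothetical method sketch (file names and layout are assumptions, not the
actual implementation):

    import os
    import torch

    def load(self, load_dir):
        # Hypothetical: restore weights and optimizer state from the
        # directory named by params.load_dir.
        self.model.load_state_dict(torch.load(os.path.join(load_dir, 'model.pth')))
        self.optimizer.load_state_dict(torch.load(os.path.join(load_dir, 'optimizer.pth')))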

     def _fit(self, data_loader, n_epoch, stop_loss):

@@ -171,7 +182,10 @@ class BasicModel():
             recorder.print(f"{epoch}/{n_epoch} model training loss is {loss_value}")
             if loss_value < min_loss:
                 min_loss = loss_value
-            if loss_value < stop_loss:
+            if epoch > 0 and self.save_interval is not None and epoch % self.save_interval == 0:
+                assert hasattr(self.params, 'save_dir')
+                self.save(self.params.save_dir)
+            if stop_loss is not None and loss_value < stop_loss:
                 break
         recorder.print("Model fitted, minimal loss is ", min_loss)
         return loss_value
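
Note: with the new guard, a checkpoint is written every save_interval epochs
(epoch 0 is skipped), and stop_loss is now optional. The cadence in isolation:

    # With save_interval = 5, saves land on epochs 5, 10, 15, ...
    save_interval = 5
    saved = [e for e in range(1, 16) if e % save_interval == 0]
    assert saved == [5, 10, 15]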

@@ -181,30 +195,32 @@ class BasicModel():

     def fit(self, data_loader = None,
             X = None,
-            y = None,
-            n_epoch = 100,
-            stop_loss = 0.001):
+            y = None):
         if data_loader is None:
             params = self.params
+            collate_fn = self.collate_fn
+            transform = self.transform
+            target_transform = self.target_transform

             Y = self.str2ints(y)
-            train_dataset = XYDataset(X, Y)
+            train_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
             sampler = None
             data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params.batchSize, \
                 shuffle=True, sampler=sampler, num_workers=int(params.workers), \
-                collate_fn=alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio))
-        return self._fit(data_loader, n_epoch, stop_loss)
+                collate_fn=collate_fn)
+        return self._fit(data_loader, params.n_epoch, params.stop_loss)
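
Note: n_epoch and stop_loss now come from params instead of keyword arguments.
One caveat: params is only bound inside the data_loader-is-None branch, so
calling fit() with an explicit data_loader would raise a NameError at the
return; hoisting params = self.params above the if, as predict and val do
below, would avoid that. A minimal sketch of the params object fit() now
expects (SimpleNamespace stands in for the real config; the attribute names
are the ones used in this diff):

    from types import SimpleNamespace
    params = SimpleNamespace(batchSize=32, workers=2, n_epoch=100,
                             stop_loss=0.001, saveInterval=None)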

     def train_epoch(self, data_loader):
-        loss_avg = mutils.averager()
+        # loss_avg = mutils.averager()
         self.model.train()

+        loss_value = 0
         for i, data in enumerate(data_loader):
             X = data[0]
             Y = data[1]
-            cost = self.train_batch(X, Y)
-            loss_avg.add(cost)
-
-        loss_value = float(loss_avg.val())
-        loss_avg.reset()
+            loss = self.train_batch(X, Y)
+            loss_value += loss.item()

         return loss_value
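
Note: loss.item() detaches a Python scalar, so each batch's autograd graph can
be freed immediately. Unlike the old averager, loss_value is now the sum of
per-batch losses rather than their mean; if _fit's stop_loss threshold was
tuned against the old per-batch scale, a final division restores it:

    # Optional: keep the old per-batch-mean semantics.
    # loss_value /= max(1, len(data_loader))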

     def train_batch(self, X, Y):

@@ -212,17 +228,10 @@ class BasicModel():
         model = self.model
         criterion = self.criterion
         optimizer = self.optimizer
         converter = self.converter
         device = self.device

         # set training mode
         for p in model.parameters():
             p.requires_grad = True
         model.train()

         # init training status
-        torch.autograd.set_detect_anomaly(True)
         optimizer.zero_grad()
+        # torch.autograd.set_detect_anomaly(True)

         # model predict
         X = X.to(device)

@@ -233,41 +242,39 @@
         loss = criterion(pred_Y, Y)

         # back propagation and optimize
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         return loss
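
Note: optimizer.zero_grad() now runs twice per batch (once under "init
training status", once before backward). The second call is redundant but
harmless, since no backward pass happens between them. For reference, the
canonical ordering is:

    optimizer.zero_grad()          # clear stale gradients
    loss = criterion(model(X), Y)  # forward pass
    loss.backward()                # accumulate fresh gradients
    optimizer.step()               # apply the update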

     def _predict(self, data_loader):
         model = self.model
         criterion = self.criterion
         converter = self.converter
         params = self.params
         device = self.device

         for p in model.parameters():
             p.requires_grad = False

         model.eval()

-        n_correct = 0
-
-        results = []
-        for i, data in enumerate(data_loader):
-            X = data[0].to(device)
-            pred_Y = model(X)
-            results.append(pred_Y)
+        with torch.no_grad():
+            results = []
+            for i, data in enumerate(data_loader):
+                X = data[0].to(device)
+                pred_Y = model(X)
+                results.append(pred_Y)

         return torch.cat(results, axis=0)
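
Note: torch.no_grad() suppresses graph construction for everything inside the
block, which makes the per-parameter requires_grad=False loop above redundant
for inference (train_batch flips the flags back to True anyway). A
stripped-down equivalent of the new inference path:

    model.eval()
    with torch.no_grad():
        results = [model(data[0].to(device)) for data in data_loader]
    preds = torch.cat(results, axis=0)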

     def predict(self, data_loader = None, X = None, print_prefix = ""):
+        params = self.params
         if data_loader is None:
-            params = self.params
+            collate_fn = self.collate_fn
+            transform = self.transform
+            target_transform = self.target_transform

             Y = [0] * len(X)
-            val_dataset = XYDataset(X, Y)
+            val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
             sampler = None
             data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=params.batchSize, \
                 shuffle=False, sampler=sampler, num_workers=int(params.workers), \
-                collate_fn=alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio))
+                collate_fn=collate_fn)

         recorder = self.recorder
         recorder.print('Start Predict ', print_prefix)

@@ -275,14 +282,18 @@ class BasicModel():
         return [self.remapping[int(y)] for y in Y]

     def predict_proba(self, data_loader = None, X = None, print_prefix = ""):
+        params = self.params
         if data_loader is None:
-            params = self.params
+            collate_fn = self.collate_fn
+            transform = self.transform
+            target_transform = self.target_transform

             Y = [0] * len(X)
-            val_dataset = XYDataset(X, Y)
+            val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
             sampler = None
             data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=params.batchSize, \
                 shuffle=False, sampler=sampler, num_workers=int(params.workers), \
-                collate_fn=alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio))
+                collate_fn=collate_fn)

         recorder = self.recorder
         recorder.print('Start Predict ', print_prefix)

@@ -292,45 +303,45 @@ class BasicModel():
         model = self.model
         criterion = self.criterion
         recorder = self.recorder
         converter = self.converter
         params = self.params
         device = self.device
         recorder.print('Start val ', print_prefix)

         for p in model.parameters():
             p.requires_grad = False

         model.eval()

         n_correct = 0
         pred_num = 0
-        loss_avg = mutils.averager()
-        for i, data in enumerate(data_loader):
-            X = data[0].to(device)
-            Y = data[1].to(device)
+        loss_value = 0
+        with torch.no_grad():
+            for i, data in enumerate(data_loader):
+                X = data[0].to(device)
+                Y = data[1].to(device)

-            pred_Y = model(X)
+                pred_Y = model(X)

-            correct_num = sum(Y == pred_Y.argmax(axis=1))
-            loss = criterion(pred_Y, Y)
-            loss_avg.add(loss)
+                correct_num = sum(Y == pred_Y.argmax(axis=1))
+                loss = criterion(pred_Y, Y)
+                loss_value += loss.item()

-            n_correct += correct_num
-            pred_num += len(X)
+                n_correct += correct_num
+                pred_num += len(X)

         accuracy = float(n_correct) / float(pred_num)
-        recorder.print('[%s] Val loss: %f, accuracy: %f' % (print_prefix, loss_avg.val(), accuracy))
+        recorder.print('[%s] Val loss: %f, accuracy: %f' % (print_prefix, loss_value, accuracy))
         return accuracy
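
Note: sum(Y == pred_Y.argmax(axis=1)) applies Python's built-in sum to a bool
tensor and yields a 0-dim tensor, which float(n_correct) later unwraps. The
tensor-native form is equivalent and avoids the Python-level loop:

    correct_num = (Y == pred_Y.argmax(dim=1)).sum().item()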

     def val(self, data_loader = None, X = None, y = None, print_prefix = ""):
+        params = self.params
         if data_loader is None:
-            y = self.str2ints(y)
-            val_dataset = XYDataset(X, y)
-            params = self.params
+            collate_fn = self.collate_fn
+            transform = self.transform
+            target_transform = self.target_transform
+
+            Y = self.str2ints(y)
+            val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
             sampler = None
             data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=params.batchSize, \
                 shuffle=True, sampler=sampler, num_workers=int(params.workers), \
-                collate_fn=alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio))
+                collate_fn=collate_fn)
         return self._val(data_loader, print_prefix)
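
Note: the validation loader keeps shuffle=True; the reported accuracy is
order-independent, so shuffle=False (as in predict and predict_proba above)
would give identical results with slightly less work.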

     def score(self, data_loader = None, X = None, y = None, print_prefix = ""):