Browse Source

更新embed_loader:

* 添加fast_load_embedding方法,用vocab的词索引pre-trained中的embedding
* 如果vocab有词没出现在pre-train中,从已有embedding中正态采样

Update embed_loader:
* add fast_load_embedding method, to index pre-trained embedding with words in Vocab
* If words in Vocab are not exist in pre-trained, sample them from normal distribution computed by current embeddings
tags/v0.2.0^2
FengZiYjun 6 years ago
parent
commit
e6864ea7e0
1 changed files with 98 additions and 61 deletions
  1. +98
    -61
      fastNLP/core/trainer.py

+ 98
- 61
fastNLP/core/trainer.py View File

@@ -1,39 +1,38 @@
import itertools
import os
import time
import warnings
from collections import defaultdict
from datetime import datetime
from datetime import timedelta

import torch
from torch import nn
from tensorboardX import SummaryWriter
from torch import nn

from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.tester import Tester
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature
from fastNLP.core.dataset import DataSet

from fastNLP.core.losses import LossBase
from fastNLP.core.metrics import MetricBase
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.utils import CheckError

class Trainer(object):
"""Main Training Loop

"""
def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,

def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1,
validate_every=-1,
dev_data=None, use_cuda=False, save_path="./save",
optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
metric_key=None,
**kwargs):
super(Trainer, self).__init__()

@@ -50,6 +49,13 @@ class Trainer(object):

# prepare evaluate
metrics = _prepare_metrics(metrics)

# parse metric_key
# increase_better is True. It means the exp result gets better if the indicator increases.
# It is true by default.
self.increase_better = False if metric_key[0] == "-" else True
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key

# prepare loss
losser = _prepare_losser(losser)

@@ -67,7 +73,7 @@ class Trainer(object):
self.save_path = save_path
self.print_every = int(print_every)
self.validate_every = int(validate_every)
self._best_accuracy = 0
self.best_metric_indicator = None

self._model_device = model.parameters().__next__().device

@@ -102,7 +108,7 @@ class Trainer(object):
if torch.cuda.is_available() and self.use_cuda:
self.model = self.model.cuda()

self.mode(self.model, is_test=False)
self._mode(self.model, is_test=False)

start = time.time()
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
@@ -112,7 +118,9 @@ class Trainer(object):
def __getattr__(self, item):
def pass_func(*args, **kwargs):
pass

return pass_func

self._summary_writer = psudoSW()
else:
path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
@@ -121,13 +129,14 @@ class Trainer(object):
epoch = 1
while epoch <= self.n_epochs:

data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False)
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(),
as_numpy=False)

self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)

# validate_every override validation at end of epochs
if self.dev_data and self.validate_every <= 0:
self.do_validation()
self._do_validation()
epoch += 1
finally:
self._summary_writer.close()
@@ -144,10 +153,10 @@ class Trainer(object):
for batch_x, batch_y in data_iterator:
# TODO 这里可能会遇到问题,万一用户在model内部修改了prediction的device就会有问题
_move_dict_value_to_device(self._model_device, batch_x, batch_y)
prediction = self.data_forward(model, batch_x)
loss = self.get_loss(prediction, batch_y)
self.grad_backward(loss)
self.update()
prediction = self._data_forward(model, batch_x)
loss = self._compute_loss(prediction, batch_y)
self._grad_backward(loss)
self._update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
for name, param in self.model.named_parameters():
if param.requires_grad:
@@ -162,18 +171,18 @@ class Trainer(object):
print(print_output)

if self.validate_every > 0 and self.step % self.validate_every == 0:
self.do_validation()
self._do_validation()

self.step += 1

def do_validation(self):
def _do_validation(self):
res = self.tester.test()
for name, num in res.items():
self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
if self.save_path is not None and self.best_eval_result(res):
if self.save_path is not None and self._better_eval_result(res):
self.save_model(self.model, 'best_model_' + self.start_time)

def mode(self, model, is_test=False):
def _mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.

:param model: a PyTorch model
@@ -185,20 +194,20 @@ class Trainer(object):
else:
model.train()

def update(self):
def _update(self):
"""Perform weight update on a model.

"""
self.optimizer.step()

def data_forward(self, network, x):
def _data_forward(self, network, x):
x = _build_args(network.forward, **x)
y = network(**x)
if not isinstance(y, dict):
raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.")
return y

def grad_backward(self, loss):
def _grad_backward(self, loss):
"""Compute gradient with link rules.

:param loss: a scalar where back-prop starts
@@ -208,7 +217,7 @@ class Trainer(object):
self.model.zero_grad()
loss.backward()

def get_loss(self, predict, truth):
def _compute_loss(self, predict, truth):
"""Compute loss given prediction and ground truth.

:param predict: prediction dict, produced by model.forward
@@ -224,27 +233,52 @@ class Trainer(object):
else:
torch.save(model, model_name)

def best_eval_result(self, metrics):
def _better_eval_result(self, metrics):
"""Check if the current epoch yields better validation results.

:return: bool, True means current results on dev set is the best.
:return bool value: True means current results on dev set is the best.
"""
if isinstance(metrics, tuple):
loss, metrics = metrics

if isinstance(metrics, dict):
if len(metrics) == 1:
accuracy = list(metrics.values())[0]
# only single metric, just use it
metric_dict = list(metrics.values())[0]
metrics_name = list(metrics.keys())[0]
else:
accuracy = metrics[self.eval_sort_key]
else:
accuracy = metrics

if accuracy > self._best_accuracy:
self._best_accuracy = accuracy
return True
else:
return False
metrics_name = self.metrics[0].__class__.__name__
if metrics_name not in metrics:
raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
metric_dict = metrics[metrics_name]

if len(metric_dict) == 1:
indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
elif len(metric_dict) > 1 and self.metric_key is None:
raise RuntimeError(
f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
else:
# metric_key is set
if self.metric_key not in metric_dict:
raise RuntimeError(f"matric key {self.metric_key} not found in {metric_dict}")
indicator_val = metric_dict[self.metric_key]

is_better = True
if self.best_metric_indicator is None:
# first-time validation
self.best_metric_indicator = indicator_val
else:
if self.increase_better is True:
if indicator_val > self.best_metric_indicator:
self.best_metric_indicator = indicator_val
else:
is_better = False
else:
if indicator_val < self.best_metric_indicator:
self.best_metric_indicator = indicator_val
else:
is_better = False
return is_better


DEFAULT_CHECK_BATCH_SIZE = 2
@@ -254,6 +288,7 @@ IGNORE_CHECK_LEVEL = 0
WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2


def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
dev_data=None,
check_level=WARNING_CHECK_LEVEL):
@@ -264,7 +299,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
for batch_count, (batch_x, batch_y) in enumerate(batch):
_move_dict_value_to_device(model_devcie, batch_x, batch_y)
# forward check
if batch_count==0:
if batch_count == 0:
_check_forward_error(model_func=model.forward, check_level=check_level,
batch_x=batch_x)

@@ -285,17 +320,17 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
if batch_count == 0:
if not isinstance(loss, torch.Tensor):
raise TypeError(f"The return value of {get_func_signature(losser.__call__)} should be `torch.Tensor`, "
f"but got `{type(loss)}`.")
if len(loss.size())!=0:
f"but got `{type(loss)}`.")
if len(loss.size()) != 0:
raise ValueError(f"The size of return value of {get_func_signature(losser.__call__)} is {loss.size()}, "
f"should be torch.size([])")
loss.backward()
model.zero_grad()
if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
break

if dev_data is not None:
tester = Tester(data=dataset[:batch_size*DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
batch_size=batch_size, verbose=-1)
tester.test()

@@ -305,18 +340,18 @@ def _check_forward_error(model_func, check_level, batch_x):
_missing = ''
_unused = ''
func_signature = get_func_signature(model_func)
if len(check_res['missing'])!=0:
if len(check_res['missing']) != 0:
_missing = "Function {} misses {}, only provided with {}, " \
".\n".format(func_signature, check_res.missing,
list(batch_x.keys()))
if len(check_res['unused'])!=0:
list(batch_x.keys()))
if len(check_res['unused']) != 0:
if len(check_res.unused) > 1:
_unused = "{} are not used ".format(check_res.unused)
else:
_unused = "{} is not used ".format(check_res.unused)
_unused += "in function {}.\n".format(func_signature)
if _missing:
if len(_unused)>0 and STRICT_CHECK_LEVEL:
if len(_unused) > 0 and STRICT_CHECK_LEVEL:
_error_str = "(1).{}\n(2).{}".format(_missing, _unused)
else:
_error_str = _missing
@@ -329,38 +364,40 @@ def _check_forward_error(model_func, check_level, batch_x):
elif check_level == WARNING_CHECK_LEVEL:
warnings.warn(message=_unused)

def _check_loss_evaluate(prev_func, func, check_res, output, batch_y, check_level):

def _check_loss_evaluate(prev_func, func, check_level, output, batch_y):
check_res = _check_arg_dict_list(func, [output, batch_y])
_missing = ''
_unused = ''
_duplicated = ''
func_signature = get_func_signature(func)
prev_func_signature = get_func_signature(prev_func)
if len(check_res.missing)>0:
if len(check_res.missing) > 0:
_missing = "function {} misses argument {}, \n\t only provided with {}(from {}) and " \
"{}(from target in Dataset)." \
.format(func_signature, check_res.missing,
list(output.keys()), prev_func_signature,
list(batch_y.keys()))
if len(check_res.unused)>0:
.format(func_signature, check_res.missing,
list(output.keys()), prev_func_signature,
list(batch_y.keys()))
if len(check_res.unused) > 0:
if len(check_res.unused) > 1:
_unused = "{} are not used ".format(check_res.unused)
else:
_unused = "{} is not used ".format(check_res.unused)
_unused += "in function {}.\n".format(func_signature)
if len(check_res.duplicated)>0:
if len(check_res.duplicated) > 0:
if len(check_res.duplicated) > 1:
_duplicated = "duplicated keys {} are detected when calling function {}. \n\tDon't set {} as target and output " \
"them in {} at the same time.".format(check_res.duplicated,
func_signature,
check_res.duplicated,
prev_func_signature)
else:
_duplicated = "duplicated key {} is detected when calling function {}. \n\tDon't set {} as target and output " \
"it in {} at the same time.".format(check_res.duplicated,
func_signature,
check_res.duplicated,
prev_func_signature)
_number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0)
else:
_duplicated = "duplicated key {} is detected when calling function {}. \n\tDon't set {} as target and output " \
"it in {} at the same time.".format(check_res.duplicated,
func_signature,
check_res.duplicated,
prev_func_signature)
_number_errs = int(len(_missing) != 0) + int(len(_duplicated) != 0) + int(len(_unused) != 0)
if _number_errs > 0:
_error_strs = []
if _number_errs > 1:


Loading…
Cancel
Save