@@ -8,20 +8,21 @@ from tensorboardX import SummaryWriter
 from torch import nn
 
 from fastNLP.core.batch import Batch
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.losses import _prepare_losser
+from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.optimizer import Adam
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.losses import _prepare_losser
-from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.utils import CheckError
-from fastNLP.core.utils import _check_loss_evaluate
-from fastNLP.core.utils import _check_forward_error
 from fastNLP.core.utils import _build_args
+from fastNLP.core.utils import _check_forward_error
+from fastNLP.core.utils import _check_loss_evaluate
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.utils import get_func_signature
+
 
 class Trainer(object):
     """Main Training Loop
@@ -33,6 +34,30 @@ class Trainer(object):
                  optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
                  metric_key=None,
                  **kwargs):
+        """
+        :param DataSet train_data: the training data
+        :param torch.nn.modules.module model: a PyTorch model
+        :param LossBase losser: a loss object
+        :param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics
+        :param int n_epochs: the number of training epochs
+        :param int batch_size: batch size for training and validation
+        :param int print_every: step interval for printing training information. Default: -1 (no printing).
+        :param int validate_every: step interval for running validation. Default: -1 (validate once per epoch).
+        :param DataSet dev_data: the validation data
+        :param bool use_cuda: whether to run the model on CUDA (GPU)
+        :param str save_path: file path to save models
+        :param Optimizer optimizer: an optimizer object
+        :param int check_code_level: level of the FastNLP code checker. 0: ignore. 1: warning. 2: strict.
+        :param str metric_key: a single indicator used to decide the best model based on metric results. It must be
+            one of the keys returned by the FIRST metric in `metrics`. If the overall result gets better when the
+            indicator gets smaller, add a `-` character in front of the string. For example
+            ::
+
+                metric_key="-PPL"   # language model gets better as perplexity gets smaller
+
+        :param kwargs:
+        """
         super(Trainer, self).__init__()
 
         if not isinstance(train_data, DataSet):
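
For orientation while reading this diff, the constructor documented above would be driven roughly as follows. This is a sketch only: train_set, dev_set, my_model, my_loss and my_metric are placeholders rather than objects from the patch, and the argument values are illustrative.

    # Hypothetical usage of the documented constructor; all data/model/metric
    # objects below are stand-ins, not defined anywhere in this patch.
    trainer = Trainer(train_data=train_set, model=my_model,
                      losser=my_loss, metrics=my_metric, dev_data=dev_set,
                      n_epochs=10, batch_size=32,
                      metric_key="-PPL",        # smaller perplexity counts as better
                      check_code_level=1)       # warn (but do not fail) on check issues
    trainer.train()
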
@@ -56,12 +81,15 @@ class Trainer(object):
         # increase_better is True. It means the exp result gets better if the indicator increases.
         # It is true by default.
         self.increase_better = False if metric_key[0] == "-" else True
-        self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
+        if metric_key is not None:
+            self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
+        else:
+            self.metric_key = None
 
         # prepare loss
         losser = _prepare_losser(losser)
 
-        if check_code_level>-1:
+        if check_code_level > -1:
             _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
                         check_level=check_code_level)
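
The new `if metric_key is not None:` guard only changes how the key is stored; the sign convention itself is untouched. As a standalone illustration of that convention (the helper below is invented for this note, it is not part of the patch):

    # Illustrative helper, not in the patch: how a "+"/"-" prefixed metric_key
    # maps onto (increase_better, bare_key).
    def parse_metric_key(metric_key):
        if metric_key is None:
            return True, None                     # no key: assume larger is better
        increase_better = metric_key[0] != "-"    # leading "-" means smaller is better
        if metric_key[0] in ("+", "-"):
            metric_key = metric_key[1:]           # strip the sign prefix
        return increase_better, metric_key

    # parse_metric_key("-PPL") -> (False, "PPL")
    # parse_metric_key("acc")  -> (True, "acc")
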
@@ -144,12 +172,13 @@ class Trainer(object):
         del self._summary_writer
 
     def _train_epoch(self, data_iterator, model, epoch, start):
-        """Training process in one epoch.
-
-            kwargs should contain:
-                - n_print: int, print training information every n steps.
-                - start: time.time(), the starting time of this step.
-                - epoch: int,
+        """
+
+        :param data_iterator:
+        :param model:
+        :param epoch:
+        :param start:
+        :return:
         """
         for batch_x, batch_y in data_iterator:
             # TODO: this may break if the user changes the device of `prediction` inside the model
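
The loop body is cut off by this hunk. For readers unfamiliar with the surrounding code, one step over such (batch_x, batch_y) dictionaries typically looks like the generic PyTorch sketch below; the helper names are invented here and are not taken from the patch.

    # Generic per-batch training step; compute_loss and optimizer are stand-ins.
    for batch_x, batch_y in data_iterator:
        prediction = model(**batch_x)             # forward pass on the feature dict
        loss = compute_loss(prediction, batch_y)  # match predictions against targets
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
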
@@ -188,7 +217,7 @@ class Trainer(object):
         """Train mode or Test mode. This is for PyTorch currently.
 
         :param model: a PyTorch model
-        :param is_test: bool, whether in test mode or not.
+        :param bool is_test: whether in test mode or not.
 
         """
         if is_test:
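
The docstring tweak above documents the standard PyTorch switch between nn.Module.train() and nn.Module.eval(). A helper like this conventionally reduces to the sketch below; the body is inferred from the docstring and the `if is_test:` line, since the hunk does not show it in full.

    # Conventional PyTorch mode switch: dropout and batch-norm layers behave
    # differently under model.eval() than under model.train().
    def _mode(model, is_test=False):
        if is_test:
            model.eval()
        else:
            model.train()
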
@@ -241,52 +270,29 @@ class Trainer(object):
         :return bool value: True means current results on dev set is the best.
         """
-        if isinstance(metrics, tuple):
-            loss, metrics = metrics
-
-        if isinstance(metrics, dict):
-            if len(metrics) == 1:
-                # only single metric, just use it
-                metric_dict = list(metrics.values())[0]
-                metrics_name = list(metrics.keys())[0]
-            else:
-                metrics_name = self.metrics[0].__class__.__name__
-                if metrics_name not in metrics:
-                    raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
-                metric_dict = metrics[metrics_name]
-
-        if len(metric_dict) == 1:
-            indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
-        elif len(metric_dict) > 1 and self.metric_key is None:
-            raise RuntimeError(
-                f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
-        else:
-            # metric_key is set
-            if self.metric_key not in metric_dict:
-                raise RuntimeError(f"matric key {self.metric_key} not found in {metric_dict}")
-            indicator_val = metric_dict[self.metric_key]
-
-        is_better = True
-        if self.best_metric_indicator is None:
-            # first-time validation
-            self.best_metric_indicator = indicator_val
-        else:
-            if self.increase_better is True:
-                if indicator_val > self.best_metric_indicator:
-                    self.best_metric_indicator = indicator_val
-                else:
-                    is_better = False
-            else:
-                if indicator_val < self.best_metric_indicator:
-                    self.best_metric_indicator = indicator_val
-                else:
-                    is_better = False
-        return is_better
+        indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics)
+
+        is_better = True
+        if self.best_metric_indicator is None:
+            # first-time validation
+            self.best_metric_indicator = indicator_val
+        else:
+            if self.increase_better is True:
+                if indicator_val > self.best_metric_indicator:
+                    self.best_metric_indicator = indicator_val
+                else:
+                    is_better = False
+            else:
+                if indicator_val < self.best_metric_indicator:
+                    self.best_metric_indicator = indicator_val
+                else:
+                    is_better = False
+        return is_better
 
 
 DEFAULT_CHECK_BATCH_SIZE = 2
 DEFAULT_CHECK_NUM_BATCH = 2
 
 
 def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                 dev_data=None,
                 check_level=0):
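
Aside from delegating the metric lookup to _check_eval_results, the logic that survives the refactor is a plain best-so-far comparison. Stripped of Trainer state it amounts to the sketch below; the helper name is invented for this note.

    # Illustrative stand-alone version of the comparison kept above.
    def is_better_result(indicator_val, best_so_far, increase_better):
        if best_so_far is None:
            return True                         # first validation always counts as best
        if increase_better:
            return indicator_val > best_so_far  # e.g. accuracy: larger is better
        return indicator_val < best_so_far      # e.g. perplexity: smaller is better

    # is_better_result(0.82, 0.80, increase_better=True)  -> True
    # is_better_result(35.1, 30.2, increase_better=False) -> False
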
@@ -337,3 +343,34 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
         # TODO: check whether the values returned here are reasonable
 
 
+def _check_eval_results(metrics, metric_key, metric_list):
+    # metrics: the evaluation results returned by the Tester
+    # metric_key: a single indicator used for model selection, from the Trainer's initialization
+    # metric_list: the metric objects used for evaluation, from the Trainer's initialization
+    if isinstance(metrics, tuple):
+        loss, metrics = metrics
+
+    if isinstance(metrics, dict):
+        if len(metrics) == 1:
+            # only single metric, just use it
+            metric_dict = list(metrics.values())[0]
+            metrics_name = list(metrics.keys())[0]
+        else:
+            metrics_name = metric_list[0].__class__.__name__
+            if metrics_name not in metrics:
+                raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
+            metric_dict = metrics[metrics_name]
+
+        if len(metric_dict) == 1:
+            indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
+        elif len(metric_dict) > 1 and metric_key is None:
+            raise RuntimeError(
+                f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
+        else:
+            # metric_key is set
+            if metric_key not in metric_dict:
+                raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}")
+            indicator_val = metric_dict[metric_key]
+    else:
+        raise RuntimeError("Invalid metrics type. Expect {}, got {}".format((tuple, dict), type(metrics)))
+    return indicator_val
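
A short usage sketch of the new helper, with a made-up result payload in the nested {metric_class_name: {key: value}} shape handled above; trainer_metrics stands in for the Trainer's list of MetricBase objects and is not defined in the patch.

    # Hypothetical Tester output for illustration only.
    eval_results = {"AccuracyMetric": {"acc": 0.91, "f1": 0.88}}

    _check_eval_results(eval_results, "acc", trainer_metrics)                      # -> 0.91
    _check_eval_results({"AccuracyMetric": {"acc": 0.91}}, None, trainer_metrics)  # -> 0.91 (single key, metric_key may stay None)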