
Merge branch 'trainer' of github.com:FengZiYjun/fastNLP into trainer

tags/v0.2.0^2
yh 6 years ago
parent commit f76851b982
3 changed files with 176 additions and 66 deletions:
  1. fastNLP/core/optimizer.py    +54 -4
  2. fastNLP/core/trainer.py      +89 -52
  3. test/core/test_optimizer.py  +33 -10

fastNLP/core/optimizer.py  (+54 -4)

@@ -3,14 +3,41 @@ import torch


 class Optimizer(object):
     def __init__(self, model_params, **kwargs):
-        if model_params is not None and not isinstance(model_params, torch.Tensor):
-            raise RuntimeError("model parameters should be torch.Tensor, rather than {}".format(type(model_params)))
+        if model_params is not None and not hasattr(model_params, "__next__"):
+            raise RuntimeError("model parameters should be a generator, rather than {}".format(type(model_params)))
         self.model_params = model_params
         self.settings = kwargs




 class SGD(Optimizer):
-    def __init__(self, model_params=None, lr=0.001, momentum=0.9):
+    def __init__(self, *args, **kwargs):
+        model_params, lr, momentum = None, 0.01, 0.9
+        if len(args) == 0 and len(kwargs) == 0:
+            # SGD()
+            pass
+        elif len(args) == 1 and len(kwargs) == 0:
+            if isinstance(args[0], float) or isinstance(args[0], int):
+                # SGD(0.001)
+                lr = args[0]
+            elif hasattr(args[0], "__next__"):
+                # SGD(model.parameters()), args[0] is a generator
+                model_params = args[0]
+            else:
+                raise RuntimeError("Not supported type {}.".format(type(args[0])))
+        elif 2 >= len(kwargs) > 0 and len(args) <= 1:
+            # SGD(lr=0.01), SGD(lr=0.01, momentum=0.9), SGD(model.parameters(), lr=0.1, momentum=0.9)
+            if len(args) == 1:
+                if hasattr(args[0], "__next__"):
+                    model_params = args[0]
+                else:
+                    raise RuntimeError("Not supported type {}.".format(type(args[0])))
+            if not all(key in ("lr", "momentum") for key in kwargs):
+                raise RuntimeError("Invalid SGD arguments. Expect {}, got {}.".format(("lr", "momentum"), kwargs))
+            lr = kwargs.get("lr", 0.01)
+            momentum = kwargs.get("momentum", 0.9)
+        else:
+            raise RuntimeError("SGD only accepts 0 or 1 sequential argument, but got {}: {}".format(len(args), args))
+
         super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)


     def construct_from_pytorch(self, model_params):
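
Note: the effect of the new __next__ guard in Optimizer.__init__ can be sanity-checked in a REPL (illustrative snippet, not part of the commit). Module.parameters() returns a generator, which the guard accepts; the bare tensors the old check expected are now rejected:

    import torch

    m = torch.nn.Linear(4, 2)
    print(hasattr(m.parameters(), "__next__"))  # True  -> accepted as model_params
    print(hasattr(torch.zeros(3), "__next__"))  # False -> Optimizer.__init__ raises RuntimeError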
@@ -20,7 +47,30 @@ class SGD(Optimizer):




 class Adam(Optimizer):
-    def __init__(self, model_params=None, lr=0.001, weight_decay=0.8):
+    def __init__(self, *args, **kwargs):
+        model_params, lr, weight_decay = None, 0.01, 0.9
+        if len(args) == 0 and len(kwargs) == 0:
+            pass
+        elif len(args) == 1 and len(kwargs) == 0:
+            if isinstance(args[0], float) or isinstance(args[0], int):
+                lr = args[0]
+            elif hasattr(args[0], "__next__"):
+                model_params = args[0]
+            else:
+                raise RuntimeError("Not supported type {}.".format(type(args[0])))
+        elif 2 >= len(kwargs) > 0 and len(args) <= 1:
+            if len(args) == 1:
+                if hasattr(args[0], "__next__"):
+                    model_params = args[0]
+                else:
+                    raise RuntimeError("Not supported type {}.".format(type(args[0])))
+            if not all(key in ("lr", "weight_decay") for key in kwargs):
+                raise RuntimeError("Invalid Adam arguments. Expect {}, got {}.".format(("lr", "weight_decay"), kwargs))
+            lr = kwargs.get("lr", 0.01)
+            weight_decay = kwargs.get("weight_decay", 0.9)
+        else:
+            raise RuntimeError("Adam only accepts 0 or 1 sequential argument, but got {}: {}".format(len(args), args))
+
         super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)


     def construct_from_pytorch(self, model_params):
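
Note: taken together, the rewritten constructors let SGD and Adam be built from a bare learning rate, a parameter generator, keyword settings, or a mix, with everything stashed in self.settings until a model is bound via construct_from_pytorch. A minimal usage sketch (assuming the classes above are importable as shown):

    import torch
    from fastNLP.core.optimizer import SGD, Adam

    model = torch.nn.Linear(10, 3)

    opt = SGD(0.001)                        # positional lr only
    opt = SGD(model.parameters())           # parameter generator only
    opt = Adam(lr=0.002, weight_decay=0.1)  # keyword settings only
    print(opt.settings)                     # {'lr': 0.002, 'weight_decay': 0.1}

    # SGD("???")            # would raise RuntimeError: neither a number nor a generator
    # SGD(0.001, lr=0.002)  # would raise RuntimeError: float positional arg not allowed with kwargs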


fastNLP/core/trainer.py  (+89 -52)

@@ -8,20 +8,21 @@ from tensorboardX import SummaryWriter
 from torch import nn

 from fastNLP.core.batch import Batch
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.losses import _prepare_losser
+from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.optimizer import Adam
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.losses import _prepare_losser
-from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.utils import CheckError
-from fastNLP.core.utils import _check_loss_evaluate
-from fastNLP.core.utils import _check_forward_error
 from fastNLP.core.utils import _build_args
+from fastNLP.core.utils import _check_forward_error
+from fastNLP.core.utils import _check_loss_evaluate
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.utils import get_func_signature



 class Trainer(object):
     """Main Training Loop


@@ -33,6 +34,30 @@ class Trainer(object):
                  optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
                  metric_key=None,
                  **kwargs):
"""

:param DataSet train_data: the training data
:param torch.nn.modules.module model: a PyTorch model
:param LossBase losser: a loss object
:param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics
:param int n_epochs: the number of training epochs
:param int batch_size: batch size for training and validation
:param int print_every: step interval to print next training information. Default: -1(no print).
:param int validate_every: step interval to do next validation. Default: -1(validate every epoch).
:param DataSet dev_data: the validation data
:param use_cuda:
:param str save_path: file path to save models
:param Optimizer optimizer: an optimizer object
:param int check_code_level: level of FastNLP code checker. 0: ignore. 1: warning. 2: strict.
:param str metric_key: a single indicator used to decide the best model based on metric results. It must be one
of the keys returned by the FIRST metric in `metrics`. If the overall result gets better if the indicator gets
smaller, add a `-` character in front of the string. For example
::
metric_key="-PPL" # language model gets better as perplexity gets smaller

:param kwargs:

"""
         super(Trainer, self).__init__()

         if not isinstance(train_data, DataSet):
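
Note: the metric_key convention documented in the new docstring maps directly onto the initializer logic in the next hunk (values illustrative):

    metric_key = "-PPL"
    increase_better = False if metric_key[0] == "-" else True  # False: smaller perplexity is better
    metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key  # "PPL"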
@@ -56,12 +81,15 @@ class Trainer(object):
         # increase_better is True. It means the exp result gets better if the indicator increases.
         # It is true by default.
         self.increase_better = False if metric_key[0] == "-" else True
-        self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
+        if metric_key is not None:
+            self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
+        else:
+            self.metric_key = None

         # prepare loss
         losser = _prepare_losser(losser)

-        if check_code_level>-1:
+        if check_code_level > -1:
             _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
                         check_level=check_code_level)


@@ -144,12 +172,13 @@ class Trainer(object):
         del self._summary_writer


     def _train_epoch(self, data_iterator, model, epoch, start):
-        """Training process in one epoch.
+        """

-        kwargs should contain:
-            - n_print: int, print training information every n steps.
-            - start: time.time(), the starting time of this step.
-            - epoch: int,
+        :param data_iterator:
+        :param model:
+        :param epoch:
+        :param start:
+        :return:
         """
         for batch_x, batch_y in data_iterator:
             # TODO potential problem here: if the user changes the device of `prediction` inside the model, this will break
@@ -188,7 +217,7 @@ class Trainer(object):
"""Train mode or Test mode. This is for PyTorch currently. """Train mode or Test mode. This is for PyTorch currently.


:param model: a PyTorch model :param model: a PyTorch model
:param is_test: bool, whether in test mode or not.
:param bool is_test: whether in test mode or not.


""" """
if is_test: if is_test:
@@ -241,52 +270,29 @@ class Trainer(object):


         :return bool value: True means current results on dev set is the best.
         """
-        if isinstance(metrics, tuple):
-            loss, metrics = metrics
-
-        if isinstance(metrics, dict):
-            if len(metrics) == 1:
-                # only single metric, just use it
-                metric_dict = list(metrics.values())[0]
-                metrics_name = list(metrics.keys())[0]
-            else:
-                metrics_name = self.metrics[0].__class__.__name__
-                if metrics_name not in metrics:
-                    raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
-                metric_dict = metrics[metrics_name]
-
-        if len(metric_dict) == 1:
-            indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
-        elif len(metric_dict) > 1 and self.metric_key is None:
-            raise RuntimeError(
-                f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
-        else:
-            # metric_key is set
-            if self.metric_key not in metric_dict:
-                raise RuntimeError(f"matric key {self.metric_key} not found in {metric_dict}")
-            indicator_val = metric_dict[self.metric_key]
-
-        is_better = True
-        if self.best_metric_indicator is None:
-            # first-time validation
-            self.best_metric_indicator = indicator_val
+        indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics)
+        is_better = True
+        if self.best_metric_indicator is None:
+            # first-time validation
+            self.best_metric_indicator = indicator_val
         else:
             if self.increase_better is True:
                 if indicator_val > self.best_metric_indicator:
                     self.best_metric_indicator = indicator_val
                 else:
                     is_better = False
             else:
                 if indicator_val < self.best_metric_indicator:
                     self.best_metric_indicator = indicator_val
                 else:
                     is_better = False
         return is_better




 DEFAULT_CHECK_BATCH_SIZE = 2
 DEFAULT_CHECK_NUM_BATCH = 2



 def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                 dev_data=None,
                 check_level=0):
@@ -337,3 +343,34 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
     # TODO need to check here whether the returned values are reasonable




+def _check_eval_results(metrics, metric_key, metric_list):
+    # metrics: the results returned by the tester
+    # metric_key: a single indicator used for model selection, from the Trainer's initialization
+    # metric_list: the metrics used for evaluation, from the Trainer's initialization
+    if isinstance(metrics, tuple):
+        loss, metrics = metrics
+
+    if isinstance(metrics, dict):
+        if len(metrics) == 1:
+            # only single metric, just use it
+            metric_dict = list(metrics.values())[0]
+            metrics_name = list(metrics.keys())[0]
+        else:
+            metrics_name = metric_list[0].__class__.__name__
+            if metrics_name not in metrics:
+                raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
+            metric_dict = metrics[metrics_name]
+
+        if len(metric_dict) == 1:
+            indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
+        elif len(metric_dict) > 1 and metric_key is None:
+            raise RuntimeError(
+                f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
+        else:
+            # metric_key is set
+            if metric_key not in metric_dict:
+                raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}")
+            indicator_val = metric_dict[metric_key]
+    else:
+        raise RuntimeError("Invalid metrics type. Expect {}, got {}".format((tuple, dict), type(metrics)))
+    return indicator_val

test/core/test_optimizer.py  (+33 -10)

@@ -2,20 +2,43 @@ import unittest


 import torch


-from fastNLP.core.optimizer import SGD
+from fastNLP.core.optimizer import SGD, Adam




 class TestOptim(unittest.TestCase):
-    def test_case(self):
-        optim = SGD(torch.LongTensor(10))
-        print(optim.__dict__)
+    def test_SGD(self):
+        optim = SGD(torch.nn.Linear(10, 3).parameters())
+        self.assertTrue("lr" in optim.__dict__["settings"])
+        self.assertTrue("momentum" in optim.__dict__["settings"])

-        optim_2 = SGD(lr=0.001)
-        print(optim_2.__dict__)
+        optim = SGD(0.001)
+        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)

-        optim_2 = SGD(lr=0.002, momentum=0.989)
-        print(optim_2.__dict__)
+        optim = SGD(lr=0.001)
+        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)

-    def test_case_2(self):
+        optim = SGD(lr=0.002, momentum=0.989)
+        self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
+        self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989)
+
+        with self.assertRaises(RuntimeError):
+            _ = SGD("???")
         with self.assertRaises(RuntimeError):
-            _ = SGD(0.001)
+            _ = SGD(0.001, lr=0.002)
+        with self.assertRaises(RuntimeError):
+            _ = SGD(lr=0.009, shit=9000)
+
+    def test_Adam(self):
+        optim = Adam(torch.nn.Linear(10, 3).parameters())
+        self.assertTrue("lr" in optim.__dict__["settings"])
+        self.assertTrue("weight_decay" in optim.__dict__["settings"])
+
+        optim = Adam(0.001)
+        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
+
+        optim = Adam(lr=0.001)
+        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
+
+        optim = Adam(lr=0.002, weight_decay=0.989)
+        self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
+        self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989)
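
Note: the new tests should be runnable with the standard unittest runner from the repository root (invocation assumed):

    python -m unittest test.core.test_optimizer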
