Browse Source

Merge branch 'trainer' of github.com:FengZiYjun/fastNLP into trainer

tags/v0.2.0^2
yh 6 years ago
parent
commit
f76851b982
3 changed files with 176 additions and 66 deletions
  1. +54
    -4
      fastNLP/core/optimizer.py
  2. +89
    -52
      fastNLP/core/trainer.py
  3. +33
    -10
      test/core/test_optimizer.py

+ 54
- 4
fastNLP/core/optimizer.py View File

@@ -3,14 +3,41 @@ import torch

class Optimizer(object):
def __init__(self, model_params, **kwargs):
if model_params is not None and not isinstance(model_params, torch.Tensor):
raise RuntimeError("model parameters should be torch.Tensor, rather than {}".format(type(model_params)))
if model_params is not None and not hasattr(model_params, "__next__"):
raise RuntimeError("model parameters should be a generator, rather than {}".format(type(model_params)))
self.model_params = model_params
self.settings = kwargs


class SGD(Optimizer):
def __init__(self, model_params=None, lr=0.001, momentum=0.9):
def __init__(self, *args, **kwargs):
model_params, lr, momentum = None, 0.01, 0.9
if len(args) == 0 and len(kwargs) == 0:
# SGD()
pass
elif len(args) == 1 and len(kwargs) == 0:
if isinstance(args[0], float) or isinstance(args[0], int):
# SGD(0.001)
lr = args[0]
elif hasattr(args[0], "__next__"):
# SGD(model.parameters()) args[0] is a generator
model_params = args[0]
else:
raise RuntimeError("Not supported type {}.".format(type(args[0])))
elif 2 >= len(kwargs) > 0 and len(args) <= 1:
# SGD(lr=0.01), SGD(lr=0.01, momentum=0.9), SGD(model.parameters(), lr=0.1, momentum=0.9)
if len(args) == 1:
if hasattr(args[0], "__next__"):
model_params = args[0]
else:
raise RuntimeError("Not supported type {}.".format(type(args[0])))
if not all(key in ("lr", "momentum") for key in kwargs):
raise RuntimeError("Invalid SGD arguments. Expect {}, got {}.".format(("lr", "momentum"), kwargs))
lr = kwargs.get("lr", 0.01)
momentum = kwargs.get("momentum", 0.9)
else:
raise RuntimeError("SGD only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args))

super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)

def construct_from_pytorch(self, model_params):
@@ -20,7 +47,30 @@ class SGD(Optimizer):


class Adam(Optimizer):
def __init__(self, model_params=None, lr=0.001, weight_decay=0.8):
def __init__(self, *args, **kwargs):
model_params, lr, weight_decay = None, 0.01, 0.9
if len(args) == 0 and len(kwargs) == 0:
pass
elif len(args) == 1 and len(kwargs) == 0:
if isinstance(args[0], float) or isinstance(args[0], int):
lr = args[0]
elif hasattr(args[0], "__next__"):
model_params = args[0]
else:
raise RuntimeError("Not supported type {}.".format(type(args[0])))
elif 2 >= len(kwargs) > 0 and len(args) <= 1:
if len(args) == 1:
if hasattr(args[0], "__next__"):
model_params = args[0]
else:
raise RuntimeError("Not supported type {}.".format(type(args[0])))
if not all(key in ("lr", "weight_decay") for key in kwargs):
raise RuntimeError("Invalid Adam arguments. Expect {}, got {}.".format(("lr", "weight_decay"), kwargs))
lr = kwargs.get("lr", 0.01)
weight_decay = kwargs.get("weight_decay", 0.9)
else:
raise RuntimeError("Adam only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args))

super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)

def construct_from_pytorch(self, model_params):


+ 89
- 52
fastNLP/core/trainer.py View File

@@ -8,20 +8,21 @@ from tensorboardX import SummaryWriter
from torch import nn

from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.optimizer import Adam
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.tester import Tester
from fastNLP.core.dataset import DataSet
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _check_loss_evaluate
from fastNLP.core.utils import _check_forward_error
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_forward_error
from fastNLP.core.utils import _check_loss_evaluate
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature


class Trainer(object):
"""Main Training Loop

@@ -33,6 +34,30 @@ class Trainer(object):
optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
metric_key=None,
**kwargs):
"""

:param DataSet train_data: the training data
:param torch.nn.modules.module model: a PyTorch model
:param LossBase losser: a loss object
:param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics
:param int n_epochs: the number of training epochs
:param int batch_size: batch size for training and validation
:param int print_every: step interval to print next training information. Default: -1(no print).
:param int validate_every: step interval to do next validation. Default: -1(validate every epoch).
:param DataSet dev_data: the validation data
:param use_cuda:
:param str save_path: file path to save models
:param Optimizer optimizer: an optimizer object
:param int check_code_level: level of FastNLP code checker. 0: ignore. 1: warning. 2: strict.
:param str metric_key: a single indicator used to decide the best model based on metric results. It must be one
of the keys returned by the FIRST metric in `metrics`. If the overall result gets better if the indicator gets
smaller, add a `-` character in front of the string. For example
::
metric_key="-PPL" # language model gets better as perplexity gets smaller

:param kwargs:

"""
super(Trainer, self).__init__()

if not isinstance(train_data, DataSet):
@@ -56,12 +81,15 @@ class Trainer(object):
# increase_better is True. It means the exp result gets better if the indicator increases.
# It is true by default.
self.increase_better = False if metric_key[0] == "-" else True
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
if metric_key is not None:
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
else:
self.metric_key = None

# prepare loss
losser = _prepare_losser(losser)

if check_code_level>-1:
if check_code_level > -1:
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
check_level=check_code_level)

@@ -144,12 +172,13 @@ class Trainer(object):
del self._summary_writer

def _train_epoch(self, data_iterator, model, epoch, start):
"""Training process in one epoch.
"""

kwargs should contain:
- n_print: int, print training information every n steps.
- start: time.time(), the starting time of this step.
- epoch: int,
:param data_iterator:
:param model:
:param epoch:
:param start:
:return:
"""
for batch_x, batch_y in data_iterator:
# TODO 这里可能会遇到问题,万一用户在model内部修改了prediction的device就会有问题
@@ -188,7 +217,7 @@ class Trainer(object):
"""Train mode or Test mode. This is for PyTorch currently.

:param model: a PyTorch model
:param is_test: bool, whether in test mode or not.
:param bool is_test: whether in test mode or not.

"""
if is_test:
@@ -241,52 +270,29 @@ class Trainer(object):

:return bool value: True means current results on dev set is the best.
"""
if isinstance(metrics, tuple):
loss, metrics = metrics

if isinstance(metrics, dict):
if len(metrics) == 1:
# only single metric, just use it
metric_dict = list(metrics.values())[0]
metrics_name = list(metrics.keys())[0]
else:
metrics_name = self.metrics[0].__class__.__name__
if metrics_name not in metrics:
raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
metric_dict = metrics[metrics_name]

if len(metric_dict) == 1:
indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
elif len(metric_dict) > 1 and self.metric_key is None:
raise RuntimeError(
f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
else:
# metric_key is set
if self.metric_key not in metric_dict:
raise RuntimeError(f"matric key {self.metric_key} not found in {metric_dict}")
indicator_val = metric_dict[self.metric_key]

is_better = True
if self.best_metric_indicator is None:
# first-time validation
self.best_metric_indicator = indicator_val
indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics)
is_better = True
if self.best_metric_indicator is None:
# first-time validation
self.best_metric_indicator = indicator_val
else:
if self.increase_better is True:
if indicator_val > self.best_metric_indicator:
self.best_metric_indicator = indicator_val
else:
is_better = False
else:
if self.increase_better is True:
if indicator_val > self.best_metric_indicator:
self.best_metric_indicator = indicator_val
else:
is_better = False
if indicator_val < self.best_metric_indicator:
self.best_metric_indicator = indicator_val
else:
if indicator_val < self.best_metric_indicator:
self.best_metric_indicator = indicator_val
else:
is_better = False
return is_better
is_better = False
return is_better


DEFAULT_CHECK_BATCH_SIZE = 2
DEFAULT_CHECK_NUM_BATCH = 2


def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
dev_data=None,
check_level=0):
@@ -337,3 +343,34 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
# TODO 这里需要检查是否返回来的值是否是合理的


def _check_eval_results(metrics, metric_key, metric_list):
# metrics: tester返回的结果
# metric_key: 一个用来做筛选的指标,来自Trainer的初始化
# metric_list: 多个用来做评价的指标,来自Trainer的初始化
if isinstance(metrics, tuple):
loss, metrics = metrics

if isinstance(metrics, dict):
if len(metrics) == 1:
# only single metric, just use it
metric_dict = list(metrics.values())[0]
metrics_name = list(metrics.keys())[0]
else:
metrics_name = metric_list[0].__class__.__name__
if metrics_name not in metrics:
raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
metric_dict = metrics[metrics_name]

if len(metric_dict) == 1:
indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
elif len(metric_dict) > 1 and metric_key is None:
raise RuntimeError(
f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
else:
# metric_key is set
if metric_key not in metric_dict:
raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}")
indicator_val = metric_dict[metric_key]
else:
raise RuntimeError("Invalid metrics type. Expect {}, got {}".format((tuple, dict), type(metrics)))
return indicator_val

+ 33
- 10
test/core/test_optimizer.py View File

@@ -2,20 +2,43 @@ import unittest

import torch

from fastNLP.core.optimizer import SGD
from fastNLP.core.optimizer import SGD, Adam


class TestOptim(unittest.TestCase):
def test_case(self):
optim = SGD(torch.LongTensor(10))
print(optim.__dict__)
def test_SGD(self):
optim = SGD(torch.nn.Linear(10, 3).parameters())
self.assertTrue("lr" in optim.__dict__["settings"])
self.assertTrue("momentum" in optim.__dict__["settings"])

optim_2 = SGD(lr=0.001)
print(optim_2.__dict__)
optim = SGD(0.001)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)

optim_2 = SGD(lr=0.002, momentum=0.989)
print(optim_2.__dict__)
optim = SGD(lr=0.001)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)

def test_case_2(self):
optim = SGD(lr=0.002, momentum=0.989)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989)

with self.assertRaises(RuntimeError):
_ = SGD("???")
with self.assertRaises(RuntimeError):
_ = SGD(0.001)
_ = SGD(0.001, lr=0.002)
with self.assertRaises(RuntimeError):
_ = SGD(lr=0.009, shit=9000)

def test_Adam(self):
optim = Adam(torch.nn.Linear(10, 3).parameters())
self.assertTrue("lr" in optim.__dict__["settings"])
self.assertTrue("weight_decay" in optim.__dict__["settings"])

optim = Adam(0.001)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)

optim = Adam(lr=0.001)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)

optim = Adam(lr=0.002, weight_decay=0.989)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989)

Loading…
Cancel
Save