
Merge branch 'trainer' of https://github.com/FengZiYjun/fastNLP into check

tags/v0.2.0^2
xuyige 6 years ago
commit ba7b17661c
4 changed files with 145 additions and 120 deletions:

  1. fastNLP/core/optimizer.py      +18  -51
  2. fastNLP/core/trainer.py        +105 -69
  3. test/core/test_optimizer.py    +21  -0
  4. test/core/test_trainer.py      +1   -0

fastNLP/core/optimizer.py (+18, -51)

@@ -2,61 +2,28 @@ import torch
 
 
 class Optimizer(object):
-    """Wrapper of optimizer from framework
-
-    1. Adam: lr (float), weight_decay (float)
-    2. AdaGrad
-    3. RMSProp
-    4. SGD: lr (float), momentum (float)
-
-    """
-
-    def __init__(self, optimizer_name, **kwargs):
-        """
-        :param optimizer_name: str, the name of the optimizer
-        :param kwargs: the arguments
-
-        """
-        self.optim_name = optimizer_name
-        self.kwargs = kwargs
-
-    @property
-    def name(self):
-        """The name of the optimizer.
-
-        :return: str
-        """
-        return self.optim_name
-
-    @property
-    def params(self):
-        """The arguments used to create the optimizer.
-
-        :return: dict of (str, *)
-        """
-        return self.kwargs
+    def __init__(self, model_params, **kwargs):
+        if model_params is not None and not isinstance(model_params, torch.Tensor):
+            raise RuntimeError("model parameters should be torch.Tensor, rather than {}".format(type(model_params)))
+        self.model_params = model_params
+        self.settings = kwargs
+
+
+class SGD(Optimizer):
+    def __init__(self, model_params=None, lr=0.001, momentum=0.9):
+        super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)
 
     def construct_from_pytorch(self, model_params):
-        """Construct a optimizer from framework over given model parameters."""
-
-        if self.optim_name in ["SGD", "sgd"]:
-            if "lr" in self.kwargs:
-                if "momentum" not in self.kwargs:
-                    self.kwargs["momentum"] = 0
-                optimizer = torch.optim.SGD(model_params, lr=self.kwargs["lr"], momentum=self.kwargs["momentum"])
-            else:
-                raise ValueError("requires learning rate for SGD optimizer")
-
-        elif self.optim_name in ["adam", "Adam"]:
-            if "lr" in self.kwargs:
-                if "weight_decay" not in self.kwargs:
-                    self.kwargs["weight_decay"] = 0
-                optimizer = torch.optim.Adam(model_params, lr=self.kwargs["lr"],
-                                             weight_decay=self.kwargs["weight_decay"])
-            else:
-                raise ValueError("requires learning rate for Adam optimizer")
-
-        else:
-            raise NotImplementedError
-
-        return optimizer
+        if self.model_params is None:
+            self.model_params = model_params
+        return torch.optim.SGD(self.model_params, **self.settings)
+
+
+class Adam(Optimizer):
+    def __init__(self, model_params=None, lr=0.001, weight_decay=0.8):
+        super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)
+
+    def construct_from_pytorch(self, model_params):
+        if self.model_params is None:
+            self.model_params = model_params
+        return torch.optim.Adam(self.model_params, **self.settings)
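
In short, the name-dispatching Optimizer wrapper is replaced by thin per-algorithm classes. A minimal usage sketch under the new API (the model and the construct_from_pytorch call site are illustrative; the wrapper only becomes a real torch.optim object once it is given the model parameters):

from fastNLP.core.optimizer import Adam, SGD

adam = Adam(lr=0.01, weight_decay=0)   # model_params defaults to None
sgd = SGD(lr=0.002, momentum=0.989)

# Later, e.g. inside the Trainer, the wrapper is materialized:
# torch_optimizer = adam.construct_from_pytorch(model.parameters())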

fastNLP/core/trainer.py (+105, -69)

@@ -1,39 +1,38 @@
-import itertools
 import os
 import time
 import warnings
-from collections import defaultdict
 from datetime import datetime
 from datetime import timedelta
 
 import torch
-from torch import nn
 from tensorboardX import SummaryWriter
+from torch import nn
 
 from fastNLP.core.batch import Batch
-from fastNLP.core.optimizer import Optimizer
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.losses import _prepare_losser
+from fastNLP.core.metrics import _prepare_metrics
+from fastNLP.core.optimizer import Adam
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
+from fastNLP.core.utils import CheckError
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.utils import get_func_signature
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.losses import LossBase
-from fastNLP.core.metrics import MetricBase
-from fastNLP.core.losses import _prepare_losser
-from fastNLP.core.metrics import _prepare_metrics
-from fastNLP.core.utils import CheckError
 
 
 class Trainer(object):
     """Main Training Loop
 
     """
-    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
+
+    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1,
+                 validate_every=-1,
                  dev_data=None, use_cuda=False, save_path="./save",
-                 optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
+                 optimizer=Adam(lr=0.01, weight_decay=0), need_check_code=True,
+                 metric_key=None,
                  **kwargs):
         super(Trainer, self).__init__()
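
The net effect of the signature change is that Trainer now receives an optimizer object and an optional signed metric_key instead of a name-based Optimizer("Adam", ...). An illustrative construction under the new signature (train_set, dev_set, model, losser and metrics are assumed to be built elsewhere):

trainer = Trainer(train_data=train_set, model=model,
                  losser=losser, metrics=metrics, dev_data=dev_set,
                  optimizer=Adam(lr=0.01, weight_decay=0),
                  metric_key="+acc",
                  n_epochs=3, batch_size=32)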


@@ -50,6 +49,13 @@ class Trainer(object):
 
         # prepare evaluate
         metrics = _prepare_metrics(metrics)
 
+        # parse metric_key
+        # increase_better is True. It means the exp result gets better if the indicator increases.
+        # It is true by default.
+        self.increase_better = False if metric_key[0] == "-" else True
+        self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
+
         # prepare loss
         losser = _prepare_losser(losser)
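
The new metric_key convention packs the comparison direction into an optional leading sign. A standalone sketch of the same parsing (the helper name is made up here; like the code above, it assumes metric_key is a non-empty string):

def parse_metric_key(metric_key):
    # "-" means lower is better; "+" or no sign means higher is better.
    increase_better = False if metric_key[0] == "-" else True
    key = metric_key[1:] if metric_key[0] in ("+", "-") else metric_key
    return key, increase_better

parse_metric_key("+acc")   # ("acc", True)
parse_metric_key("-loss")  # ("loss", False)
parse_metric_key("f1")     # ("f1", True)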


@@ -67,12 +73,10 @@ class Trainer(object):
         self.save_path = save_path
         self.print_every = int(print_every)
         self.validate_every = int(validate_every)
-        self._best_accuracy = 0
+        self.best_metric_indicator = None
 
         self._model_device = model.parameters().__next__().device
 
-        # TODO: self._best_accuracy cannot represent the case where there are multiple metrics
-
         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
         else:
@@ -102,7 +106,7 @@ class Trainer(object):
         if torch.cuda.is_available() and self.use_cuda:
             self.model = self.model.cuda()
 
-        self.mode(self.model, is_test=False)
+        self._mode(self.model, is_test=False)
 
         start = time.time()
         self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
@@ -112,7 +116,9 @@ class Trainer(object):
                 def __getattr__(self, item):
                     def pass_func(*args, **kwargs):
                         pass
+
                     return pass_func
+
             self._summary_writer = psudoSW()
         else:
             path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
@@ -121,19 +127,20 @@ class Trainer(object):
             epoch = 1
             while epoch <= self.n_epochs:
 
-                data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False)
+                data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(),
+                                      as_numpy=False)
 
-                self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)
+                self._train_epoch(data_iterator, self.model, epoch, start)
 
                 # validate_every override validation at end of epochs
                 if self.dev_data and self.validate_every <= 0:
-                    self.do_validation()
+                    self._do_validation()
                 epoch += 1
         finally:
             self._summary_writer.close()
             del self._summary_writer
 
-    def _train_epoch(self, data_iterator, model, epoch, dev_data, start, **kwargs):
+    def _train_epoch(self, data_iterator, model, epoch, start):
         """Training process in one epoch.
 
             kwargs should contain:
@@ -144,10 +151,10 @@ class Trainer(object):
         for batch_x, batch_y in data_iterator:
             # TODO: a potential problem: if the user changes the device of the prediction inside the model, this will break
             _move_dict_value_to_device(self._model_device, batch_x, batch_y)
-            prediction = self.data_forward(model, batch_x)
-            loss = self.get_loss(prediction, batch_y)
-            self.grad_backward(loss)
-            self.update()
+            prediction = self._data_forward(model, batch_x)
+            loss = self._compute_loss(prediction, batch_y)
+            self._grad_backward(loss)
+            self._update()
             self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
             for name, param in self.model.named_parameters():
                 if param.requires_grad:
@@ -162,18 +169,19 @@ class Trainer(object):
                     print(print_output)
 
             if self.validate_every > 0 and self.step % self.validate_every == 0:
-                self.do_validation()
+                self._do_validation()
 
             self.step += 1
 
-    def do_validation(self):
+    def _do_validation(self):
         res = self.tester.test()
         for name, num in res.items():
             self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
-        if self.save_path is not None and self.best_eval_result(res):
-            self.save_model(self.model, 'best_model_' + self.start_time)
+        if self.save_path is not None and self._better_eval_result(res):
+            self._save_model(self.model,
+                             "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))
 
-    def mode(self, model, is_test=False):
+    def _mode(self, model, is_test=False):
         """Train mode or Test mode. This is for PyTorch currently.
 
         :param model: a PyTorch model
@@ -185,20 +193,20 @@ class Trainer(object):
         else:
             model.train()
 
-    def update(self):
+    def _update(self):
         """Perform weight update on a model.
 
         """
         self.optimizer.step()
 
-    def data_forward(self, network, x):
+    def _data_forward(self, network, x):
         x = _build_args(network.forward, **x)
         y = network(**x)
         if not isinstance(y, dict):
             raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.")
         return y
 
-    def grad_backward(self, loss):
+    def _grad_backward(self, loss):
         """Compute gradient with link rules.
 
         :param loss: a scalar where back-prop starts
@@ -208,7 +216,7 @@ class Trainer(object):
         self.model.zero_grad()
         loss.backward()
 
-    def get_loss(self, predict, truth):
+    def _compute_loss(self, predict, truth):
         """Compute loss given prediction and ground truth.
 
         :param predict: prediction dict, produced by model.forward
@@ -217,34 +225,59 @@ class Trainer(object):
         """
         return self.losser(predict, truth)
 
-    def save_model(self, model, model_name, only_param=False):
+    def _save_model(self, model, model_name, only_param=False):
         model_name = os.path.join(self.save_path, model_name)
         if only_param:
             torch.save(model.state_dict(), model_name)
         else:
             torch.save(model, model_name)
 
-    def best_eval_result(self, metrics):
+    def _better_eval_result(self, metrics):
         """Check if the current epoch yields better validation results.
 
-        :return: bool, True means current results on dev set is the best.
+        :return bool value: True means current results on dev set is the best.
         """
         if isinstance(metrics, tuple):
             loss, metrics = metrics
 
         if isinstance(metrics, dict):
             if len(metrics) == 1:
-                accuracy = list(metrics.values())[0]
+                # only single metric, just use it
+                metric_dict = list(metrics.values())[0]
+                metrics_name = list(metrics.keys())[0]
             else:
-                accuracy = metrics[self.eval_sort_key]
-        else:
-            accuracy = metrics
-
-        if accuracy > self._best_accuracy:
-            self._best_accuracy = accuracy
-            return True
-        else:
-            return False
+                metrics_name = self.metrics[0].__class__.__name__
+                if metrics_name not in metrics:
+                    raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
+                metric_dict = metrics[metrics_name]
+
+        if len(metric_dict) == 1:
+            indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
+        elif len(metric_dict) > 1 and self.metric_key is None:
+            raise RuntimeError(
+                f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
+        else:
+            # metric_key is set
+            if self.metric_key not in metric_dict:
+                raise RuntimeError(f"matric key {self.metric_key} not found in {metric_dict}")
+            indicator_val = metric_dict[self.metric_key]
+
+        is_better = True
+        if self.best_metric_indicator is None:
+            # first-time validation
+            self.best_metric_indicator = indicator_val
+        else:
+            if self.increase_better is True:
+                if indicator_val > self.best_metric_indicator:
+                    self.best_metric_indicator = indicator_val
+                else:
+                    is_better = False
+            else:
+                if indicator_val < self.best_metric_indicator:
+                    self.best_metric_indicator = indicator_val
+                else:
+                    is_better = False
+        return is_better
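
For orientation, this is how a validation result dict flows through the new selection logic (metric and key names are illustrative only):

res = {"AccuracyMetric": {"acc": 0.91}}         # one metric object, one key
metric_dict = list(res.values())[0]             # {"acc": 0.91}
indicator_val = list(metric_dict.values())[0]   # 0.91, compared against self.best_metric_indicator

res = {"AccuracyMetric": {"acc": 0.91, "f1": 0.88}}
# With several keys, self.metric_key (e.g. "acc", parsed from "+acc") selects the indicator;
# if metric_key was never set, the method raises RuntimeError instead of guessing.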




 DEFAULT_CHECK_BATCH_SIZE = 2

@@ -254,6 +287,7 @@ IGNORE_CHECK_LEVEL = 0
 WARNING_CHECK_LEVEL = 1
 STRICT_CHECK_LEVEL = 2
 
+
 def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                 dev_data=None,
                 check_level=WARNING_CHECK_LEVEL):
@@ -264,7 +298,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
     for batch_count, (batch_x, batch_y) in enumerate(batch):
         _move_dict_value_to_device(model_devcie, batch_x, batch_y)
         # forward check
-        if batch_count==0:
+        if batch_count == 0:
             _check_forward_error(model_func=model.forward, check_level=check_level,
                                  batch_x=batch_x)


@@ -285,17 +319,17 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
         if batch_count == 0:
             if not isinstance(loss, torch.Tensor):
                 raise TypeError(f"The return value of {get_func_signature(losser.__call__)} should be `torch.Tensor`, "
-                                f"but got `{type(loss)}`.")
-            if len(loss.size())!=0:
+                                f"but got `{type(loss)}`.")
+            if len(loss.size()) != 0:
                 raise ValueError(f"The size of return value of {get_func_signature(losser.__call__)} is {loss.size()}, "
                                  f"should be torch.size([])")
         loss.backward()
         model.zero_grad()
-        if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
+        if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
             break
 
     if dev_data is not None:
-        tester = Tester(data=dataset[:batch_size*DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
+        tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
                         batch_size=batch_size, verbose=-1)
         tester.test()


@@ -305,18 +339,18 @@ def _check_forward_error(model_func, check_level, batch_x):
     _missing = ''
     _unused = ''
     func_signature = get_func_signature(model_func)
-    if len(check_res['missing'])!=0:
+    if len(check_res['missing']) != 0:
         _missing = "Function {} misses {}, only provided with {}, " \
                    ".\n".format(func_signature, check_res.missing,
-                                list(batch_x.keys()))
-    if len(check_res['unused'])!=0:
+                                list(batch_x.keys()))
+    if len(check_res['unused']) != 0:
         if len(check_res.unused) > 1:
             _unused = "{} are not used ".format(check_res.unused)
         else:
             _unused = "{} is not used ".format(check_res.unused)
         _unused += "in function {}.\n".format(func_signature)
     if _missing:
-        if len(_unused)>0 and STRICT_CHECK_LEVEL:
+        if len(_unused) > 0 and STRICT_CHECK_LEVEL:
            _error_str = "(1).{}\n(2).{}".format(_missing, _unused)
        else:
            _error_str = _missing
@@ -329,38 +363,40 @@ def _check_forward_error(model_func, check_level, batch_x):
     elif check_level == WARNING_CHECK_LEVEL:
         warnings.warn(message=_unused)
 
-def _check_loss_evaluate(prev_func, func, check_res, output, batch_y, check_level):
+
+def _check_loss_evaluate(prev_func, func, check_level, output, batch_y):
+    check_res = _check_arg_dict_list(func, [output, batch_y])
     _missing = ''
     _unused = ''
     _duplicated = ''
     func_signature = get_func_signature(func)
     prev_func_signature = get_func_signature(prev_func)
-    if len(check_res.missing)>0:
+    if len(check_res.missing) > 0:
         _missing = "function {} misses argument {}, \n\t only provided with {}(from {}) and " \
                    "{}(from target in Dataset)." \
             .format(func_signature, check_res.missing,
                     list(output.keys()), prev_func_signature,
                     list(batch_y.keys()))
-    if len(check_res.unused)>0:
+    if len(check_res.unused) > 0:
         if len(check_res.unused) > 1:
             _unused = "{} are not used ".format(check_res.unused)
         else:
             _unused = "{} is not used ".format(check_res.unused)
         _unused += "in function {}.\n".format(func_signature)
-    if len(check_res.duplicated)>0:
+    if len(check_res.duplicated) > 0:
         if len(check_res.duplicated) > 1:
             _duplicated = "duplicated keys {} are detected when calling function {}. \n\tDon't set {} as target and output " \
                           "them in {} at the same time.".format(check_res.duplicated,
                                                                 func_signature,
                                                                 check_res.duplicated,
                                                                 prev_func_signature)
         else:
             _duplicated = "duplicated key {} is detected when calling function {}. \n\tDon't set {} as target and output " \
                           "it in {} at the same time.".format(check_res.duplicated,
                                                               func_signature,
                                                               check_res.duplicated,
                                                               prev_func_signature)
-    _number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0)
+    _number_errs = int(len(_missing) != 0) + int(len(_duplicated) != 0) + int(len(_unused) != 0)
     if _number_errs > 0:
         _error_strs = []
         if _number_errs > 1:


test/core/test_optimizer.py (+21, -0)

@@ -0,0 +1,21 @@
+import unittest
+
+import torch
+
+from fastNLP.core.optimizer import SGD
+
+
+class TestOptim(unittest.TestCase):
+    def test_case(self):
+        optim = SGD(torch.LongTensor(10))
+        print(optim.__dict__)
+
+        optim_2 = SGD(lr=0.001)
+        print(optim_2.__dict__)
+
+        optim_2 = SGD(lr=0.002, momentum=0.989)
+        print(optim_2.__dict__)
+
+    def test_case_2(self):
+        with self.assertRaises(RuntimeError):
+            _ = SGD(0.001)
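
The RuntimeError case exercises the new type check in Optimizer.__init__: a bare positional float is interpreted as model_params, not as a learning rate. Roughly:

SGD(lr=0.001)                        # ok: model_params stays None until construct_from_pytorch
SGD(torch.LongTensor(10), lr=0.01)   # ok: the first positional argument is model_params
SGD(0.001)                           # RuntimeError: 0.001 is not a torch.Tensor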

test/core/test_trainer.py (+1, -0)

@@ -4,3 +4,4 @@ import unittest
 class TestTrainer(unittest.TestCase):
     def test_case_1(self):
         pass
+

