Browse Source

trainer迭代

tags/v0.2.0^2
yh 6 years ago
parent
commit
3d91f2f024
3 changed files with 148 additions and 50 deletions
  1. +10
    -8
      fastNLP/core/tester.py
  2. +80
    -37
      fastNLP/core/trainer.py
  3. +58
    -5
      fastNLP/core/utils.py

+ 10
- 8
fastNLP/core/tester.py View File

@@ -6,33 +6,34 @@ import torch
from fastNLP.core.batch import Batch from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler from fastNLP.core.sampler import RandomSampler
from fastNLP.core.utils import _build_args from fastNLP.core.utils import _build_args
from fastNLP.core.utils import get_func_signature


class Tester(object): class Tester(object):
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ """An collection of model inference and evaluation of performance, used over validation/dev set and test set. """


def __init__(self, data, model, batch_size=16, use_cuda=False):
def __init__(self, data, model, metrics, batch_size=16, use_cuda=False, verbose=0):
super(Tester, self).__init__() super(Tester, self).__init__()
self.use_cuda = use_cuda self.use_cuda = use_cuda
self.data = data self.data = data
self.batch_size = batch_size self.batch_size = batch_size
self.verbose = verbose
if torch.cuda.is_available() and self.use_cuda: if torch.cuda.is_available() and self.use_cuda:
self._model = model.cuda() self._model = model.cuda()
else: else:
self._model = model self._model = model
if hasattr(self._model, 'predict'): if hasattr(self._model, 'predict'):
assert callable(self._model.predict)
if not callable(self._model.predict):
raise TypeError(f"{get_func_signature(model.predict)} must be callable to be used "
f"for evaluation.")
self._predict_func = self._model.predict self._predict_func = self._model.predict
else: else:
self._predict_func = self._model self._predict_func = self._model
assert hasattr(model, 'evaluate')
self._evaluator = model.evaluate
self.eval_history = [] # evaluation results of all batches



def test(self): def test(self):
# turn on the testing mode; clean up the history # turn on the testing mode; clean up the history
network = self._model network = self._model
self.mode(network, is_test=True) self.mode(network, is_test=True)
self.eval_history.clear()
output, truths = defaultdict(list), defaultdict(list) output, truths = defaultdict(list), defaultdict(list)
data_iterator = Batch(self.data, self.batch_size, sampler=RandomSampler(), as_numpy=False) data_iterator = Batch(self.data, self.batch_size, sampler=RandomSampler(), as_numpy=False)


@@ -48,9 +49,10 @@ class Tester(object):
output[k] = itertools.chain(*v) output[k] = itertools.chain(*v)
for k, v in truths.items(): for k, v in truths.items():
truths[k] = itertools.chain(*v) truths[k] = itertools.chain(*v)
args = _build_args(self._evaluator, **output, **truths)
# args = _build_args(self._evaluator, **output, **truths)
eval_results = self._evaluator(**args) eval_results = self._evaluator(**args)
print("[tester] {}".format(self.print_eval_results(eval_results)))
if self.verbose >= 0:
print("[tester] {}".format(self.print_eval_results(eval_results)))
self.mode(network, is_test=False) self.mode(network, is_test=False)
return eval_results return eval_results




+ 80
- 37
fastNLP/core/trainer.py View File

@@ -9,6 +9,7 @@ import shutil


from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
import torch import torch
from torch import nn


from fastNLP.core.batch import Batch from fastNLP.core.batch import Batch
from fastNLP.core.loss import Loss from fastNLP.core.loss import Loss
@@ -21,12 +22,13 @@ from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _build_args from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _syn_model_data from fastNLP.core.utils import _syn_model_data
from fastNLP.core.utils import get_func_signature from fastNLP.core.utils import get_func_signature
from fastNLP.core.dataset import DataSet


class Trainer(object): class Trainer(object):
"""Main Training Loop """Main Training Loop


""" """
def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
dev_data=None, use_cuda=False, save_path="./save", dev_data=None, use_cuda=False, save_path="./save",
optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True, optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
**kwargs): **kwargs):
@@ -35,6 +37,8 @@ class Trainer(object):
self.train_data = train_data self.train_data = train_data
self.dev_data = dev_data # If None, No validation. self.dev_data = dev_data # If None, No validation.
self.model = model self.model = model
self.losser = losser
self.metrics = metrics
self.n_epochs = int(n_epochs) self.n_epochs = int(n_epochs)
self.batch_size = int(batch_size) self.batch_size = int(batch_size)
self.use_cuda = bool(use_cuda) self.use_cuda = bool(use_cuda)
@@ -43,23 +47,22 @@ class Trainer(object):
self.validate_every = int(validate_every) self.validate_every = int(validate_every)
self._best_accuracy = 0 self._best_accuracy = 0


if need_check_code:
_check_code(dataset=train_data, model=model, dev_data=dev_data)


model_name = model.__class__.__name__
assert hasattr(self.model, 'get_loss'), "model {} has to have a 'get_loss' function.".format(model_name)
self.loss_func = self.model.get_loss
# TODO check loss与metrics的类型



# TODO self._best_accuracy不能表现出当前的metric多种的情况

if isinstance(optimizer, torch.optim.Optimizer): if isinstance(optimizer, torch.optim.Optimizer):
self.optimizer = optimizer self.optimizer = optimizer
else: else:
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())


assert hasattr(self.model, 'evaluate'), "model {} has to have a 'evaluate' function.".format(model_name)
self.evaluator = self.model.evaluate

if self.dev_data is not None: if self.dev_data is not None:
self.tester = Tester(model=self.model, self.tester = Tester(model=self.model,
data=self.dev_data, data=self.dev_data,
metrics=self.metrics,
batch_size=self.batch_size, batch_size=self.batch_size,
use_cuda=self.use_cuda) use_cuda=self.use_cuda)


@@ -71,6 +74,38 @@ class Trainer(object):


# print(self.__dict__) # print(self.__dict__)


def _check_params(self, train_data, model, losser, metrics=[], n_epochs=3, batch_size=32, print_every=-1,
validate_every=-1, dev_data=None, use_cuda=False, save_path="./save",
optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
**kwargs):
if not isinstance(train_data, DataSet):
raise TypeError("The type of train_data must be fastNLP.DataSet, got {}.".\
format(type(train_data)))
if not isinstance(model, nn.Module):
raise TypeError("The type of model must be torch.nn.Module, got {}.".\
format(type(model)))
if losser is not None:
# TODO change
if not isinstance(losser, None):
raise TypeError("The type of losser must be xxx, got {}.".\
format(type(losser)))

# check metrics and dev_data
if (not metrics) and dev_data is not None:
raise ValueError("No metric for dev_data evaluation.")
if metrics and (dev_data is None):
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")

# check loss
if isinstance(losser, type):
self.losser = losser()
if not isinstance(self.losser, None):
raise TypeError(f'The type of losser must be `{}`, got {type(self.losser)}.')

if need_check_code:
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data)


def train(self): def train(self):
"""Start Training. """Start Training.


@@ -171,6 +206,9 @@ class Trainer(object):
def data_forward(self, network, x): def data_forward(self, network, x):
x = _build_args(network.forward, **x) x = _build_args(network.forward, **x)
y = network(**x) y = network(**x)
if not isinstance(y, dict):

raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.")
return y return y


def grad_backward(self, loss): def grad_backward(self, loss):
@@ -231,11 +269,11 @@ IGNORE_CHECK_LEVEL = 0
WARNING_CHECK_LEVEL = 1 WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2 STRICT_CHECK_LEVEL = 2


def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL):
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
dev_data=None,
check_level=WARNING_CHECK_LEVEL):
# check get_loss 方法 # check get_loss 方法
model_name = model.__class__.__name__ model_name = model.__class__.__name__
if not hasattr(model, 'get_loss'):
raise AttributeError("{} has to have a 'get_loss' function.".format(model_name))


batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_count, (batch_x, batch_y) in enumerate(batch): for batch_count, (batch_x, batch_y) in enumerate(batch):
@@ -248,23 +286,26 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
refined_batch_x = _build_args(model.forward, **batch_x) refined_batch_x = _build_args(model.forward, **batch_x)
output = model(**refined_batch_x) output = model(**refined_batch_x)
func_signature = get_func_signature(model.forward) func_signature = get_func_signature(model.forward)
assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature)
if not isinstance(output, dict):
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`.")


# loss check # loss check
if batch_count == 0:
_check_loss_evaluate(prev_func=model.forward, func=model.get_loss, check_level=check_level,
output=output, batch_y=batch_y)
loss_input = _build_args(model.get_loss, **output, **batch_y)
loss = model.get_loss(**loss_input)
if isinstance(losser, type): # 这种情况,用户传的是losser.CE这种未初始化的loss
# 需要保证output与batch_y是无歧义的?
# (1) output和batch_y长度为1
# (2) output和batch_y的key是和losser接受的完全一致
pass

loss = losser(output, batch_y)


# check loss output # check loss output
if batch_count == 0: if batch_count == 0:
if not isinstance(loss, torch.Tensor): if not isinstance(loss, torch.Tensor):
raise ValueError("The return value of {}.get_loss() should be torch.Tensor, but {} got.".
format(model_name, type(loss)))
raise ValueError("The return value of {} should be torch.Tensor, but got {}.".
format(type(losser), type(loss)))
if len(loss.size())!=0: if len(loss.size())!=0:
raise ValueError("The size of return value of {}.get_loss() is {}, should be torch.size([])".format(
model_name, loss.size()
raise ValueError("The size of return value of {} is {}, should be torch.size([])".format(
type(losser), loss.size()
)) ))
loss.backward() loss.backward()
model.zero_grad() model.zero_grad()
@@ -272,26 +313,29 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
break break


if dev_data is not None: if dev_data is not None:
if not hasattr(model, 'evaluate'):
raise AttributeError("{} has to have a 'evaluate' function to do evaluation. Or set"
"dev_data to 'None'."
.format(model_name))
outputs, truths = defaultdict(list), defaultdict(list) outputs, truths = defaultdict(list), defaultdict(list)
dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
# TODO 这里修改为使用tester


with torch.no_grad(): with torch.no_grad():
for batch_count, (batch_x, batch_y) in enumerate(dev_batch): for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
_syn_model_data(model, batch_x, batch_y) _syn_model_data(model, batch_x, batch_y)


if hasattr(model, 'predict'): if hasattr(model, 'predict'):
if not callable(model.predict):
raise TypeError(f"{get_func_signature(model.predict)} must be callable to be used "
f"for evaluation.")
refined_batch_x = _build_args(model.predict, **batch_x) refined_batch_x = _build_args(model.predict, **batch_x)
prev_func = model.predict prev_func = model.predict
output = prev_func(**refined_batch_x) output = prev_func(**refined_batch_x)
func_signature = get_func_signature(model.predict)
assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature)
else: else:
refined_batch_x = _build_args(model.forward, **batch_x) refined_batch_x = _build_args(model.forward, **batch_x)
prev_func = model.forward prev_func = model.forward
output = prev_func(**refined_batch_x) output = prev_func(**refined_batch_x)
func_signature = get_func_signature(prev_func)
if not isinstance(output, dict):
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`")
for k, v in output.items(): for k, v in output.items():
outputs[k].append(v) outputs[k].append(v)
for k, v in batch_y.items(): for k, v in batch_y.items():
@@ -299,16 +343,15 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
if batch_count+1>DEFAULT_CHECK_NUM_BATCH: if batch_count+1>DEFAULT_CHECK_NUM_BATCH:
break break
for k, v in outputs.items(): for k, v in outputs.items():
outputs[k] = itertools.chain(*v)
outputs[k] = tuple(itertools.chain(*v))
for k, v in truths.items(): for k, v in truths.items():
truths[k] = itertools.chain(*v)
_check_loss_evaluate(prev_func=prev_func, func=model.evaluate, check_level=check_level,
output=outputs, batch_y=truths)
refined_input = _build_args(model.evaluate, **outputs, **truths)
metrics = model.evaluate(**refined_input)
func_signature = get_func_signature(model.evaluate)
assert isinstance(metrics, dict), "The return value of {} should be dict.". \
format(func_signature)
truths[k] = tuple(itertools.chain(*v))
#TODO 这里需要根据新版的metrics做修改,另外这里需要捕获来自metric的报错,因为需要指导用户debug









def _check_forward_error(model_func, check_level, batch_x): def _check_forward_error(model_func, check_level, batch_x):


+ 58
- 5
fastNLP/core/utils.py View File

@@ -3,6 +3,7 @@ import inspect
import os import os
from collections import Counter from collections import Counter
from collections import namedtuple from collections import namedtuple
import torch


CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False) CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False)


@@ -95,7 +96,24 @@ def _check_arg_dict_list(func, args):
all_needed=list(all_args)) all_needed=list(all_args))


def get_func_signature(func): def get_func_signature(func):
# can only be used in function or class method
"""

Given a function or method, return its signature.
For example:
(1) function
def func(a, b='a', *args):
xxxx
get_func_signature(func) # 'func(a, b='a', *args)'
(2) method
class Demo:
def __init__(self):
xxx
def forward(self, a, b='a', **args)
demo = Demo()
get_func_signature(demo.forward) # 'Demo.forward(self, a, b='a', **args)'
:param func: a function or a method
:return: str or None
"""
if inspect.ismethod(func): if inspect.ismethod(func):
class_name = func.__self__.__class__.__name__ class_name = func.__self__.__class__.__name__
signature = inspect.signature(func) signature = inspect.signature(func)
@@ -113,10 +131,16 @@ def get_func_signature(func):
return signature_str return signature_str




# move data to model's device
import torch
def _syn_model_data(model, *args): def _syn_model_data(model, *args):
assert len(model.state_dict())!=0, "This model has no parameter."
"""

move data to model's device, element in *args should be dict. This is a inplace change.
:param model:
:param args:
:return:
"""
if len(model.state_dict())==0:
raise ValueError("model has no parameter.")
device = model.parameters().__next__().device device = model.parameters().__next__().device
for arg in args: for arg in args:
if isinstance(arg, dict): if isinstance(arg, dict):
@@ -124,4 +148,33 @@ def _syn_model_data(model, *args):
if isinstance(value, torch.Tensor): if isinstance(value, torch.Tensor):
arg[key] = value.to(device) arg[key] = value.to(device)
else: else:
raise ValueError("Only support dict type right now.")
raise TypeError("Only support `dict` type right now.")

def _prepare_metrics(metrics):
"""

Prepare list of Metric based on input
:param metrics:
:return:
"""
_metrics = []
if metrics:
if isinstance(metrics, list):
for metric in metrics:
if isinstance(metric, type):
metric = metric()
if isinstance(metric, None):
_metrics.append(metric)
else:
raise TypeError("The type of metric in metrics must be xxxx, not {}.".format(
type(), type(metric)
))
elif isinstance(metrics, None):
_metrics = [metrics]
else:
raise TypeError("The type of metrics should be `list[xxx]` or `xxx`, got {}.".format(
type(metrics)
))

return _metrics


Loading…
Cancel
Save