
Trainer iteration

tags/v0.2.0^2
yh 6 years ago
parent
commit 3d91f2f024
3 changed files with 148 additions and 50 deletions
  1. +10 -8   fastNLP/core/tester.py
  2. +80 -37  fastNLP/core/trainer.py
  3. +58 -5   fastNLP/core/utils.py

+10 -8  fastNLP/core/tester.py

@@ -6,33 +6,34 @@ import torch
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import get_func_signature

class Tester(object):
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. """

def __init__(self, data, model, batch_size=16, use_cuda=False):
def __init__(self, data, model, metrics, batch_size=16, use_cuda=False, verbose=0):
super(Tester, self).__init__()
self.use_cuda = use_cuda
self.data = data
self.batch_size = batch_size
self.verbose = verbose
if torch.cuda.is_available() and self.use_cuda:
self._model = model.cuda()
else:
self._model = model
if hasattr(self._model, 'predict'):
assert callable(self._model.predict)
if not callable(self._model.predict):
raise TypeError(f"{get_func_signature(model.predict)} must be callable to be used "
f"for evaluation.")
self._predict_func = self._model.predict
else:
self._predict_func = self._model
assert hasattr(model, 'evaluate')
self._evaluator = model.evaluate
self.eval_history = [] # evaluation results of all batches


def test(self):
# turn on the testing mode; clean up the history
network = self._model
self.mode(network, is_test=True)
self.eval_history.clear()
output, truths = defaultdict(list), defaultdict(list)
data_iterator = Batch(self.data, self.batch_size, sampler=RandomSampler(), as_numpy=False)

@@ -48,9 +49,10 @@ class Tester(object):
output[k] = itertools.chain(*v)
for k, v in truths.items():
truths[k] = itertools.chain(*v)
args = _build_args(self._evaluator, **output, **truths)
# args = _build_args(self._evaluator, **output, **truths)
eval_results = self._evaluator(**args)
print("[tester] {}".format(self.print_eval_results(eval_results)))
if self.verbose >= 0:
print("[tester] {}".format(self.print_eval_results(eval_results)))
self.mode(network, is_test=False)
return eval_results
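The hunk above changes the Tester constructor to take `metrics` and `verbose` in addition to the data and model. A minimal usage sketch of the new signature (the `dev_data` DataSet, the trained `model`, and the `metric` object below are hypothetical placeholders, not part of this commit):

from fastNLP.core.tester import Tester

# dev_data (a DataSet), model (an nn.Module, optionally with a `predict` method)
# and metric are assumed to exist already; they are placeholders for this sketch.
tester = Tester(data=dev_data,
                model=model,
                metrics=[metric],   # new parameter introduced by this commit
                batch_size=16,
                use_cuda=False,
                verbose=0)          # the "[tester] ..." summary is printed unless verbose < 0
eval_results = tester.test()        # runs inference over dev_data and returns the evaluator's results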



+80 -37  fastNLP/core/trainer.py

@@ -9,6 +9,7 @@ import shutil

from tensorboardX import SummaryWriter
import torch
from torch import nn

from fastNLP.core.batch import Batch
from fastNLP.core.loss import Loss
@@ -21,12 +22,13 @@ from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _syn_model_data
from fastNLP.core.utils import get_func_signature
from fastNLP.core.dataset import DataSet

class Trainer(object):
"""Main Training Loop

"""
def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
dev_data=None, use_cuda=False, save_path="./save",
optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
**kwargs):
@@ -35,6 +37,8 @@ class Trainer(object):
self.train_data = train_data
self.dev_data = dev_data # If None, No validation.
self.model = model
self.losser = losser
self.metrics = metrics
self.n_epochs = int(n_epochs)
self.batch_size = int(batch_size)
self.use_cuda = bool(use_cuda)
@@ -43,23 +47,22 @@ class Trainer(object):
self.validate_every = int(validate_every)
self._best_accuracy = 0

if need_check_code:
_check_code(dataset=train_data, model=model, dev_data=dev_data)

model_name = model.__class__.__name__
assert hasattr(self.model, 'get_loss'), "model {} has to have a 'get_loss' function.".format(model_name)
self.loss_func = self.model.get_loss
# TODO: check the types of loss and metrics



# TODO: self._best_accuracy cannot represent the case where there are multiple metrics

if isinstance(optimizer, torch.optim.Optimizer):
self.optimizer = optimizer
else:
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())

assert hasattr(self.model, 'evaluate'), "model {} has to have a 'evaluate' function.".format(model_name)
self.evaluator = self.model.evaluate

if self.dev_data is not None:
self.tester = Tester(model=self.model,
data=self.dev_data,
metrics=self.metrics,
batch_size=self.batch_size,
use_cuda=self.use_cuda)

@@ -71,6 +74,38 @@ class Trainer(object):

# print(self.__dict__)

def _check_params(self, train_data, model, losser, metrics=[], n_epochs=3, batch_size=32, print_every=-1,
validate_every=-1, dev_data=None, use_cuda=False, save_path="./save",
optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
**kwargs):
if not isinstance(train_data, DataSet):
raise TypeError("The type of train_data must be fastNLP.DataSet, got {}.".\
format(type(train_data)))
if not isinstance(model, nn.Module):
raise TypeError("The type of model must be torch.nn.Module, got {}.".\
format(type(model)))
if losser is not None:
# TODO: check against the concrete loss class once it is defined;
# isinstance() needs a type, so for now only require a callable loss.
if not callable(losser):
raise TypeError("The type of losser must be xxx, got {}.".\
format(type(losser)))

# check metrics and dev_data
if (not metrics) and dev_data is not None:
raise ValueError("No metric for dev_data evaluation.")
if metrics and (dev_data is None):
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")

# check loss
if isinstance(losser, type):
# the user passed an uninitialized loss class; instantiate it
self.losser = losser()
# TODO: check against the concrete loss class once it is defined
if self.losser is not None and not callable(self.losser):
raise TypeError(f"The type of losser must be `xxx`, got {type(self.losser)}.")

if need_check_code:
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data)


def train(self):
"""Start Training.

@@ -171,6 +206,9 @@ class Trainer(object):
def data_forward(self, network, x):
x = _build_args(network.forward, **x)
y = network(**x)
if not isinstance(y, dict):
raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.")
return y

def grad_backward(self, loss):
@@ -231,11 +269,11 @@ IGNORE_CHECK_LEVEL = 0
WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2

def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL):
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
dev_data=None,
check_level=WARNING_CHECK_LEVEL):
# check the get_loss method
model_name = model.__class__.__name__
if not hasattr(model, 'get_loss'):
raise AttributeError("{} has to have a 'get_loss' function.".format(model_name))

batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_count, (batch_x, batch_y) in enumerate(batch):
@@ -248,23 +286,26 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
refined_batch_x = _build_args(model.forward, **batch_x)
output = model(**refined_batch_x)
func_signature = get_func_signature(model.forward)
assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature)
if not isinstance(output, dict):
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`.")

# loss check
if batch_count == 0:
_check_loss_evaluate(prev_func=model.forward, func=model.get_loss, check_level=check_level,
output=output, batch_y=batch_y)
loss_input = _build_args(model.get_loss, **output, **batch_y)
loss = model.get_loss(**loss_input)
if isinstance(losser, type): # in this case the user passed an uninitialized loss such as losser.CE
# need to make sure output and batch_y are unambiguous?
# (1) output and batch_y each have exactly one entry
# (2) the keys of output and batch_y exactly match what losser accepts
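# Hypothetical illustration (names not in this commit): with output == {"pred": ...} and
# batch_y == {"truth": ...}, and an uninitialized losser.CE that accepts exactly (pred, truth),
# the binding is unambiguous; if either dict had extra keys, or keys the loss does not accept,
# the mapping from dict entries to loss arguments would be ambiguous.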
pass

loss = losser(output, batch_y)

# check loss output
if batch_count == 0:
if not isinstance(loss, torch.Tensor):
raise ValueError("The return value of {}.get_loss() should be torch.Tensor, but {} got.".
format(model_name, type(loss)))
raise ValueError("The return value of {} should be torch.Tensor, but got {}.".
format(type(losser), type(loss)))
if len(loss.size())!=0:
raise ValueError("The size of return value of {}.get_loss() is {}, should be torch.size([])".format(
model_name, loss.size()
raise ValueError("The size of return value of {} is {}, should be torch.size([])".format(
type(losser), loss.size()
))
loss.backward()
model.zero_grad()
@@ -272,26 +313,29 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
break

if dev_data is not None:
if not hasattr(model, 'evaluate'):
raise AttributeError("{} has to have a 'evaluate' function to do evaluation. Or set"
"dev_data to 'None'."
.format(model_name))
outputs, truths = defaultdict(list), defaultdict(list)
dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
# TODO: change this to use the Tester


with torch.no_grad():
for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
_syn_model_data(model, batch_x, batch_y)

if hasattr(model, 'predict'):
if not callable(model.predict):
raise TypeError(f"{get_func_signature(model.predict)} must be callable to be used "
f"for evaluation.")
refined_batch_x = _build_args(model.predict, **batch_x)
prev_func = model.predict
output = prev_func(**refined_batch_x)
func_signature = get_func_signature(model.predict)
assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature)
else:
refined_batch_x = _build_args(model.forward, **batch_x)
prev_func = model.forward
output = prev_func(**refined_batch_x)
func_signature = get_func_signature(prev_func)
if not isinstance(output, dict):
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`")
for k, v in output.items():
outputs[k].append(v)
for k, v in batch_y.items():
@@ -299,16 +343,15 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
if batch_count+1>DEFAULT_CHECK_NUM_BATCH:
break
for k, v in outputs.items():
outputs[k] = itertools.chain(*v)
outputs[k] = tuple(itertools.chain(*v))
for k, v in truths.items():
truths[k] = itertools.chain(*v)
_check_loss_evaluate(prev_func=prev_func, func=model.evaluate, check_level=check_level,
output=outputs, batch_y=truths)
refined_input = _build_args(model.evaluate, **outputs, **truths)
metrics = model.evaluate(**refined_input)
func_signature = get_func_signature(model.evaluate)
assert isinstance(metrics, dict), "The return value of {} should be dict.". \
format(func_signature)
truths[k] = tuple(itertools.chain(*v))
# TODO: this needs to be updated for the new-style metrics; errors raised by a metric should also be caught here, so the user can be guided while debugging







def _check_forward_error(model_func, check_level, batch_x):
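Taken together, the trainer.py changes move toward constructing a Trainer with an explicit loss and metric list rather than relying on model.get_loss / model.evaluate. A minimal sketch of the new constructor surface (train_data, dev_data, model, my_loss and my_metric are hypothetical placeholders; the concrete loss and metric classes are still marked TODO above):

from fastNLP.core.trainer import Trainer

# train_data / dev_data are DataSet objects, model is an nn.Module,
# my_loss and my_metric are placeholder callables for this sketch.
trainer = Trainer(train_data=train_data,
                  model=model,
                  losser=my_loss,        # called as losser(output_dict, truth_dict) -> scalar torch.Tensor
                  metrics=[my_metric],   # forwarded to the internal Tester for dev_data evaluation
                  dev_data=dev_data,
                  n_epochs=3,
                  batch_size=32,
                  need_check_code=True)  # runs _check_code on a few batches before real training
trainer.train()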


+58 -5  fastNLP/core/utils.py

@@ -3,6 +3,7 @@ import inspect
import os
from collections import Counter
from collections import namedtuple
import torch

CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False)

@@ -95,7 +96,24 @@ def _check_arg_dict_list(func, args):
all_needed=list(all_args))

def get_func_signature(func):
# can only be used in function or class method
"""

Given a function or method, return its signature.
For example:
(1) function
def func(a, b='a', *args):
xxxx
get_func_signature(func) # 'func(a, b='a', *args)'
(2) method
class Demo:
def __init__(self):
xxx
def forward(self, a, b='a', **args):
xxx
demo = Demo()
get_func_signature(demo.forward) # 'Demo.forward(self, a, b='a', **args)'
:param func: a function or a method
:return: str or None
"""
if inspect.ismethod(func):
class_name = func.__self__.__class__.__name__
signature = inspect.signature(func)
@@ -113,10 +131,16 @@ def get_func_signature(func):
return signature_str


# move data to model's device
import torch
def _syn_model_data(model, *args):
assert len(model.state_dict())!=0, "This model has no parameter."
"""

Move data to the model's device; each element in *args should be a dict. This is an in-place change.
:param model:
:param args:
:return:
"""
if len(model.state_dict())==0:
raise ValueError("model has no parameter.")
device = model.parameters().__next__().device
for arg in args:
if isinstance(arg, dict):
@@ -124,4 +148,33 @@ def _syn_model_data(model, *args):
if isinstance(value, torch.Tensor):
arg[key] = value.to(device)
else:
raise ValueError("Only support dict type right now.")
raise TypeError("Only support `dict` type right now.")

def _prepare_metrics(metrics):
"""

Prepare a list of metric objects from the input.
:param metrics:
:return:
"""
_metrics = []
if metrics:
if isinstance(metrics, list):
for metric in metrics:
if isinstance(metric, type):
metric = metric()
# TODO: check against the concrete metric base class once it is defined;
# isinstance() needs a type, so for now only require a callable metric.
if callable(metric):
_metrics.append(metric)
else:
raise TypeError("The type of metric in metrics must be xxxx, not {}.".format(
type(metric)
))
elif callable(metrics):
# a single metric was passed in; wrap it in a list (TODO: same type check as above)
_metrics = [metrics]
else:
raise TypeError("The type of metrics should be `list[xxx]` or `xxx`, got {}.".format(
type(metrics)
))

return _metrics
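A small hedged sketch of the two helpers above in use; the linear model, the batch dicts and the `accuracy` callable are placeholders invented for illustration (the real metric base class is still the `xxxx` placeholder above, so any callable passes the current check):

import torch
from torch import nn

from fastNLP.core.utils import _prepare_metrics, _syn_model_data

def accuracy(pred=None, truth=None):    # placeholder metric; returns a dict of results
    return {"acc": float((pred.argmax(dim=-1) == truth).float().mean())}

metrics = _prepare_metrics([accuracy])  # -> [accuracy]; a bare callable would also be wrapped into a list

model = nn.Linear(4, 2)                 # defines the target device
batch_x = {"x": torch.randn(8, 4)}
batch_y = {"truth": torch.zeros(8, dtype=torch.long)}
_syn_model_data(model, batch_x, batch_y)  # moves every tensor value in the dicts to model's device, in place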

