@@ -73,6 +73,29 @@ class NewLoss(LossBase):
        raise RuntimeError("")
class LossInForward(LossBase):
    def __init__(self, loss_key='loss'):
        super().__init__()
        self.loss_key = loss_key

    def get_loss(self, *args, **kwargs):
        # unused: the loss is computed inside the model's forward and fetched by __call__
        pass

    def __call__(self, output_dict, predict_dict):
        if self.loss_key not in output_dict:
            raise KeyError(f"`{self.loss_key}` not found in the output of forward(); "
                           f"LossInForward requires the model to return its loss under that key.")
        return output_dict[self.loss_key]
def _prepare_losser(losser):
    if losser is None:
        losser = LossInForward()
        return losser
    elif isinstance(losser, LossBase):
        return losser
    else:
        raise TypeError(f"Type of losser should be `fastNLP.LossBase`, got {type(losser)}")
def squash(predict, truth, **kwargs):
    '''Reshape tensors so that they fit the loss functions in PyTorch.

@@ -1,8 +1,136 @@
import warnings
import inspect

import numpy as np
import torch

from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _build_args
class MetricBase(object):
    def __init__(self):
        self.param_map = {}  # key is a parameter name of evaluate(), value is the input key it is read from
        self._checked = False

    def evaluate(self, *args, **kwargs):
        raise NotImplementedError

    def _init_param_map(self, key_map, **kwargs):
        self.param_map = {}
        for key, value in key_map.items():
            if not isinstance(key, str):
                raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
            if not isinstance(value, str):
                raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
            self.param_map[key] = value
        for key, value in kwargs.items():
            if not isinstance(value, str):
                raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
            self.param_map[key] = value
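
    # Usage sketch (names are illustrative): a subclass calls _init_param_map
    # from its __init__, mapping evaluate()'s parameter names to the keys those
    # values are read from in the output/target dicts:
    #
    #     class _DemoMetric(MetricBase):
    #         def __init__(self):
    #             super().__init__()
    #             # evaluate(pred=..., target=...) reads output_dict['pred_label']
    #             # and target_dict['gold_label']
    #             self._init_param_map({'pred': 'pred_label'}, target='gold_label')
    #
    #         def evaluate(self, pred, target):
    #             return {'acc': ...}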
    def __call__(self, output_dict, target_dict, force_check=False):
        """
        :param output_dict: dict, the outputs of the model's forward (or predict)
        :param target_dict: dict, the ground-truth fields of the dataset
        :param force_check: bool, re-run the argument check even if it passed before
        :return: dict, the return value of self.evaluate
        """
        if not callable(self.evaluate):
            raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

        if not self._checked:
            # 1. check consistency between the signature of evaluate and param_map
            func_spect = inspect.getfullargspec(self.evaluate)
            func_args = func_spect.args
            for func_param, input_param in self.param_map.items():
                if func_param not in func_args:
                    raise NameError(f"{func_param} not in {get_func_signature(self.evaluate)}.")
            # 2. only part of the params are in param_map; map the rest to themselves
            for arg in func_args:
                if arg not in self.param_map:
                    self.param_map[arg] = arg  # this param needs no mapping
            self._evaluate_args = func_args

        # rename the inputs to the parameter names evaluate() expects
        mapped_output_dict = {}
        mapped_target_dict = {}
        for func_arg in self._evaluate_args:
            input_arg = self.param_map[func_arg]
            if input_arg in output_dict:
                mapped_output_dict[func_arg] = output_dict[input_arg]
            if input_arg in target_dict:
                mapped_target_dict[func_arg] = target_dict[input_arg]

        # check duplicated, unused, missing
        if force_check or not self._checked:
            check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_target_dict])
            # report problems in terms of the users' input keys, not evaluate()'s parameter names
            for key, value in check_res.items():
                new_value = value.copy()
                for idx, func_param in enumerate(value):
                    if func_param in self.param_map:
                        new_value[idx] = self.param_map[func_param]
                check_res[key] = new_value
            if check_res['missing'] or check_res['duplicated']:
                raise CheckError(check_res=check_res)

        refined_args = _build_args(self.evaluate, **mapped_output_dict, **mapped_target_dict)
        metrics = self.evaluate(**refined_args)
        if not isinstance(metrics, dict):
            raise TypeError(f"The return value of {get_func_signature(self.evaluate)} must be `dict`, "
                            f"got {type(metrics)}.")
        self._checked = True

        return metrics
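
# End-to-end sketch of MetricBase.__call__ (values are made up): the input dicts
# are renamed through param_map, checked, and forwarded to evaluate(); a missing
# input raises CheckError.
def _demo_metric_call():
    class _DemoMetric(MetricBase):
        def __init__(self):
            super().__init__()
            self._init_param_map({'pred': 'pred_label'}, target='gold_label')

        def evaluate(self, pred, target):
            correct = sum(int(p == t) for p, t in zip(pred, target))
            return {'acc': correct / len(target)}

    metric = _DemoMetric()
    res = metric({'pred_label': [1, 0, 1]}, {'gold_label': [1, 1, 1]})
    # res == {'acc': 2 / 3}

    try:
        metric({'pred_label': [1, 0, 1]}, {}, force_check=True)  # 'gold_label' missing
    except CheckError as e:
        print(e)
    return res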
class CheckError(Exception):
    def __init__(self, check_res):
        err = ''
        if check_res['missing']:
            err += f"Missing: {check_res['missing']}\n"
        if check_res['duplicated']:
            err += f"Duplicated: {check_res['duplicated']}\n"
        super().__init__(err)
        self.check_res = check_res

    def __str__(self):
        return self.args[0] if self.args else ''
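
# Sketch: CheckError carries the full check_res dict (same shape as returned by
# `_check_arg_dict_list`) so callers can inspect it programmatically.
def _demo_check_error():
    check_res = {'missing': ['target'], 'duplicated': [],
                 'unused': [], 'required': [], 'all_needed': []}
    try:
        raise CheckError(check_res=check_res)
    except CheckError as e:
        assert 'Missing' in str(e)
        assert e.check_res['missing'] == ['target']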
class Metric(MetricBase):
    """Wrap a plain function as a metric; key_map and kwargs are passed to _init_param_map."""

    def __init__(self, func, key_map, **kwargs):
        super().__init__()
        self._init_param_map(key_map, **kwargs)
        self.evaluate = func
def _prepare_metrics(metrics):
    """
    Prepare a list of MetricBase instances from the user input.

    :param metrics: a MetricBase instance, a list of MetricBase instances/classes, or None
    :return: List[fastNLP.MetricBase]
    """
    _metrics = []
    if metrics:
        if isinstance(metrics, list):
            for metric in metrics:
                if isinstance(metric, type):
                    metric = metric()
                if isinstance(metric, MetricBase):
                    _metrics.append(metric)
                else:
                    raise TypeError(f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.")
        elif isinstance(metrics, MetricBase):
            _metrics = [metrics]
        else:
            raise TypeError("The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, got {}."
                            .format(type(metrics)))
    return _metrics
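
# Sketch of the accepted input forms (a dummy metric for illustration): a single
# instance, a list of instances, or a MetricBase subclass inside a list, which
# gets instantiated on the fly.
def _demo_prepare_metrics():
    class _DummyMetric(MetricBase):
        def evaluate(self):
            return {}

    m = _DummyMetric()
    assert _prepare_metrics(m) == [m]    # single instance
    assert _prepare_metrics([m]) == [m]  # list of instances
    assert isinstance(_prepare_metrics([_DummyMetric])[0], _DummyMetric)  # class is instantiated
    assert _prepare_metrics(None) == []  # no metrics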
class Evaluator(object):
    def __init__(self):
@@ -17,7 +145,6 @@ class Evaluator(object):
        """
        raise NotImplementedError


class ClassifyEvaluator(Evaluator):
    def __init__(self):
        super(ClassifyEvaluator, self).__init__()
@@ -2,42 +2,61 @@ import itertools
from collections import defaultdict

import torch
from torch import nn

from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.dataset import DataSet
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.metrics import _prepare_metrics
class Tester(object):
    """A collection of model inference and performance evaluation, used over validation/dev sets and test sets."""

    def __init__(self, data, model, metrics, batch_size=16, use_cuda=False, verbose=0):
        super(Tester, self).__init__()

        if not isinstance(data, DataSet):
            raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.")
        if not isinstance(model, nn.Module):
            raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.")

        self.metrics = _prepare_metrics(metrics)
        self.data = data
        self.use_cuda = use_cuda
        self.batch_size = batch_size
        self.verbose = verbose

        if torch.cuda.is_available() and self.use_cuda:
            self._model = model.cuda()
        else:
            self._model = model
        self._model_device = model.parameters().__next__().device

        # use model.predict for inference when it exists, otherwise fall back to forward
        if hasattr(self._model, 'predict'):
            self._predict_func = self._model.predict
            if not callable(self._predict_func):
                _model_name = model.__class__.__name__
                raise TypeError(f"`{_model_name}.predict` must be callable to be used "
                                f"for evaluation, not `{type(self._predict_func)}`.")
        else:
            self._predict_func = self._model

        self.eval_history = []  # evaluation results of all batches
    def test(self):
        # turn on the testing mode; clean up the history
        network = self._model
        self.mode(network, is_test=True)
        self.eval_history.clear()
        output, truths = defaultdict(list), defaultdict(list)
        data_iterator = Batch(self.data, self.batch_size, sampler=RandomSampler(), as_numpy=False)

        with torch.no_grad():
            for batch_x, batch_y in data_iterator:
                _move_dict_value_to_device(self._model_device, batch_x, batch_y)
                prediction = self.data_forward(network, batch_x)
                if not isinstance(prediction, dict):
                    raise TypeError(f"The return value of {get_func_signature(self._predict_func)} "
                                    f"should be `dict`, got `{type(prediction)}`.")
                for k, v in prediction.items():
@@ -48,9 +67,13 @@ class Tester(object):
            output[k] = itertools.chain(*v)
        for k, v in truths.items():
            truths[k] = itertools.chain(*v)
        eval_results = {}
        for metric in self.metrics:
            eval_result = metric(output, truths)
            metric_name = metric.__class__.__name__
            eval_results[metric_name] = eval_result

        if self.verbose >= 1:
            print("[tester] \n{}".format(self.format_eval_results(eval_results)))
        self.mode(network, is_test=False)
        return eval_results
@@ -72,10 +95,15 @@ class Tester(object):
        y = self._predict_func(**x)
        return y
    def format_eval_results(self, results):
        """Override this method to support more print formats.

        :param results: dict, {metric name: {result name: value}}
        """
        _str = ''
        for metric_name, metric_result in results.items():
            _str += metric_name + '\n\t'
            _str += ", ".join([str(key) + "=" + str(value) for key, value in metric_result.items()])
            _str += '\n'
        return _str
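
# Usage sketch (dataset, model and metric names are hypothetical): evaluate a
# model over a dev set with one or more metrics.
#
#   tester = Tester(data=dev_data, model=model,
#                   metrics=AccuracyMetric(), batch_size=32, verbose=1)
#   eval_results = tester.test()
#   # -> {'AccuracyMetric': {'acc': 0.92}}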
@@ -7,6 +7,7 @@ from datetime import datetime
from datetime import timedelta

import torch
from torch import nn
from tensorboardX import SummaryWriter

from fastNLP.core.batch import Batch
@@ -16,23 +17,50 @@ from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.tester import Tester
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _syn_model_data
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature
from fastNLP.core.dataset import DataSet
from fastNLP.core.losses import LossBase
from fastNLP.core.metrics import MetricBase
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
class Trainer(object):
    """Main Training Loop"""

    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32,
                 print_every=-1, validate_every=-1,
                 dev_data=None, use_cuda=False, save_path="./save",
                 optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
                 **kwargs):
        super(Trainer, self).__init__()

        if not isinstance(train_data, DataSet):
            raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.")
        if not isinstance(model, nn.Module):
            raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")

        # check metrics and dev_data
        if (not metrics) and dev_data is not None:
            raise ValueError("No metric for dev_data evaluation.")
        if metrics and (dev_data is None):
            raise ValueError("No dev_data for evaluation; pass dev_data or set metrics to None.")

        # prepare evaluation
        metrics = _prepare_metrics(metrics)
        # prepare loss
        losser = _prepare_losser(losser)

        if need_check_code:
            _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data)

        self.train_data = train_data
        self.dev_data = dev_data  # if None, no validation
        self.model = model
        self.losser = losser
        self.metrics = metrics
        self.n_epochs = int(n_epochs)
        self.batch_size = int(batch_size)
        self.use_cuda = bool(use_cuda)
@@ -41,23 +69,19 @@ class Trainer(object):
        self.validate_every = int(validate_every)
        self._best_accuracy = 0
        self._model_device = model.parameters().__next__().device

        # TODO: self._best_accuracy cannot represent the case of multiple metrics

        if isinstance(optimizer, torch.optim.Optimizer):
            self.optimizer = optimizer
        else:
            self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())

        if self.dev_data is not None:
            self.tester = Tester(model=self.model,
                                 data=self.dev_data,
                                 metrics=self.metrics,
                                 batch_size=self.batch_size,
                                 use_cuda=self.use_cuda)
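
        # Usage sketch (train_data/dev_data/model/metric names are hypothetical):
        # with losser=None the loss is taken from the 'loss' key of forward()'s
        # output via LossInForward; metrics together with dev_data enable
        # per-validation evaluation through the internal Tester.
        #
        #   trainer = Trainer(train_data=train_data, model=model,
        #                     losser=None, metrics=AccuracyMetric(),
        #                     dev_data=dev_data, n_epochs=5, batch_size=32)
        #   trainer.train()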
@@ -118,8 +142,9 @@ class Trainer(object):
            - epoch: int,
        """
        for batch_x, batch_y in data_iterator:
            # TODO: this may break if the user changes the device of the prediction inside the model
            _move_dict_value_to_device(self._model_device, batch_x, batch_y)
            prediction = self.data_forward(model, batch_x)
            loss = self.get_loss(prediction, batch_y)
            self.grad_backward(loss)
            self.update()
@@ -169,6 +194,8 @@ class Trainer(object):
    def data_forward(self, network, x):
        x = _build_args(network.forward, **x)
        y = network(**x)
        if not isinstance(y, dict):
            raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.")
        return y
    def grad_backward(self, loss):
@@ -229,11 +256,11 @@ IGNORE_CHECK_LEVEL = 0
WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2


def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                dev_data=None,
                check_level=WARNING_CHECK_LEVEL):
    batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
    for batch_count, (batch_x, batch_y) in enumerate(batch):
@@ -246,23 +273,26 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
        refined_batch_x = _build_args(model.forward, **batch_x)
        output = model(**refined_batch_x)
        func_signature = get_func_signature(model.forward)
        if not isinstance(output, dict):
            raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`.")
        # loss check
        if isinstance(losser, type):  # the user passed an uninstantiated loss class such as losser.CE
            # here output and batch_y must be unambiguous, i.e. either
            # (1) output and batch_y each contain exactly one key, or
            # (2) their keys match exactly what the losser accepts
            pass
        loss = losser(output, batch_y)
        # check loss output
        if batch_count == 0:
            if not isinstance(loss, torch.Tensor):
                raise ValueError("The return value of {} should be `torch.Tensor`, but got `{}`.".
                                 format(type(losser), type(loss)))
            if len(loss.size()) != 0:
                raise ValueError("The size of return value of {} is {}, should be torch.Size([])".format(
                    type(losser), loss.size()
                ))
        loss.backward()
        model.zero_grad()
@@ -270,26 +300,29 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
            break

    if dev_data is not None:
        outputs, truths = defaultdict(list), defaultdict(list)
        dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
        # TODO: switch this block over to the Tester
        tester = Tester(data=dataset, model=model, metrics=metrics, batch_size=batch_size)
        with torch.no_grad():
            for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
                _syn_model_data(model, batch_x, batch_y)

                if hasattr(model, 'predict'):
                    if not callable(model.predict):
                        raise TypeError(f"{get_func_signature(model.predict)} must be callable to be used "
                                        f"for evaluation.")
                    refined_batch_x = _build_args(model.predict, **batch_x)
                    prev_func = model.predict
                    output = prev_func(**refined_batch_x)
                else:
                    refined_batch_x = _build_args(model.forward, **batch_x)
                    prev_func = model.forward
                    output = prev_func(**refined_batch_x)

                func_signature = get_func_signature(prev_func)
                if not isinstance(output, dict):
                    raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`")
                for k, v in output.items():
                    outputs[k].append(v)
                for k, v in batch_y.items():
@@ -297,16 +330,15 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
                if batch_count + 1 > DEFAULT_CHECK_NUM_BATCH:
                    break
        for k, v in outputs.items():
            outputs[k] = tuple(itertools.chain(*v))
        for k, v in truths.items():
            truths[k] = tuple(itertools.chain(*v))
        # TODO: adapt this to the new-style metrics, and catch errors raised by the
        # metrics here, since users need guidance when debugging them

def _check_forward_error(model_func, check_level, batch_x):
@@ -3,9 +3,8 @@ import inspect
import os
from collections import Counter
from collections import defaultdict

import torch


def save_pickle(obj, pickle_path, file_name):
    """Save an object into a pickle file.
@@ -121,14 +120,35 @@ def _check_arg_dict_list(func, args):
    input_args = set(input_arg_count.keys())
    missing = list(require_args - input_args)
    unused = list(input_args - all_args)
    check_res = {}
    check_res['missing'] = missing
    check_res['unused'] = unused
    check_res['duplicated'] = duplicated
    check_res['required'] = list(require_args)
    check_res['all_needed'] = list(all_args)
    return check_res
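
# Sketch of the returned dict (illustrative): for a function func(a, b) checked
# against the inputs {'a': 1, 'c': 2}, 'b' is missing and 'c' is unused.
def _demo_check_arg_dict_list():
    def func(a, b):
        pass

    check_res = _check_arg_dict_list(func, [{'a': 1, 'c': 2}])
    # check_res['missing']    == ['b']
    # check_res['unused']     == ['c']
    # check_res['duplicated'] == []
    # check_res['required']   contains 'a' and 'b'
    # check_res['all_needed'] contains 'a' and 'b'
    return check_res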
def get_func_signature(func):
    """
    Given a function or method, return its signature. Can only be used on a
    function or a class method.

    For example:

    (1) function
        def func(a, b='a', *args):
            xxxx
        get_func_signature(func) # 'func(a, b='a', *args)'

    (2) method
        class Demo:
            def __init__(self):
                xxx
            def forward(self, a, b='a', **args):
                xxx
        demo = Demo()
        get_func_signature(demo.forward) # 'Demo.forward(self, a, b='a', **args)'

    :param func: a function or a method
    :return: str or None
    """
    if inspect.ismethod(func):
        class_name = func.__self__.__class__.__name__
        signature = inspect.signature(func)
@@ -146,10 +166,16 @@ def get_func_signature(func):
        return signature_str

def _syn_model_data(model, *args):
    """
    Move data to the model's device; each element in *args should be a dict. This is an in-place change.

    :param model:
    :param args:
    :return:
    """
    if len(model.state_dict()) == 0:
        raise ValueError("model has no parameter.")
    device = model.parameters().__next__().device
    for arg in args:
        if isinstance(arg, dict):
@@ -157,4 +183,24 @@ def _syn_model_data(model, *args):
                if isinstance(value, torch.Tensor):
                    arg[key] = value.to(device)
        else:
            raise TypeError("Only support `dict` type right now.")

def _move_dict_value_to_device(device, *args):
    """
    Move data to the given device; each element in *args should be a dict. This is an in-place change.

    :param device: torch.device
    :param args:
    :return:
    """
    if not isinstance(device, torch.device):
        raise TypeError(f"device must be `torch.device`, got `{type(device)}`")

    for arg in args:
        if isinstance(arg, dict):
            for key, value in arg.items():
                if isinstance(value, torch.Tensor):
                    arg[key] = value.to(device)
        else:
            raise TypeError("Only support `dict` type right now.")