@@ -17,6 +17,29 @@ class Loss(LossBase): | |||||
pass | pass | ||||
class LossInForward(LossBase):
    """Losser for models that already place their loss in the forward output.

    Instead of computing anything, it looks the loss up in the model's output
    dict under ``loss_key``.

    :param loss_key: key of the loss value inside the model's output dict.
    """

    def __init__(self, loss_key='loss'):
        super().__init__()
        self.loss_key = loss_key

    def get_loss(self, *args, **kwargs):
        """Return the value stored under ``self.loss_key`` in ``kwargs``.

        Positional arguments are accepted for signature compatibility and ignored.

        :raises KeyError: if ``self.loss_key`` is not among the keyword arguments.
        """
        if self.loss_key not in kwargs:
            raise KeyError(f"`{self.loss_key}` is not found in the model output, "
                           f"available keys: {list(kwargs.keys())}.")
        return kwargs[self.loss_key]

    def __call__(self, output_dict, predict_dict):
        """Fetch the loss from ``output_dict``.

        :param output_dict: dict returned by the model's forward pass.
        :param predict_dict: unused here; kept so the call signature is unchanged.
        :return: whatever the model stored under ``self.loss_key``.
        """
        return self.get_loss(**output_dict)
def _prepare_losser(losser):
    """Normalise the ``losser`` argument into a ``LossBase`` instance.

    :param losser: ``None`` or an instance of ``LossBase``.
    :return: a fresh ``LossInForward`` when ``losser`` is None, otherwise
        ``losser`` itself.
    :raises TypeError: if ``losser`` is neither None nor a ``LossBase``.
    """
    # No losser given: fall back to the default one.
    if losser is None:
        return LossInForward()
    if not isinstance(losser, LossBase):
        raise TypeError(f"Type of losser should be `fastNLP.LossBase`, got {type(losser)}")
    return losser
def squash(predict, truth, **kwargs): | def squash(predict, truth, **kwargs): | ||||
'''To reshape tensors in order to fit Loss functions in pytorch | '''To reshape tensors in order to fit Loss functions in pytorch | ||||
@@ -1,8 +1,136 @@ | |||||
import warnings | import warnings | ||||
import inspect | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
from fastNLP.core.utils import get_func_signature | |||||
from fastNLP.core.utils import _check_arg_dict_list | |||||
from fastNLP.core.utils import _build_args | |||||
class MetricBase(object):
    """Base class for metrics called with an output dict and a target dict.

    Subclasses implement ``evaluate``. ``__call__`` collects, from the two
    dicts, the values named in the signature of ``evaluate`` (optionally
    renamed through ``param_map``) and passes them on as keyword arguments.
    """

    def __init__(self):
        self.param_map = {}  # key is param in function, value is input param.
        self._checked = False  # becomes True after the first successful call

    def evaluate(self, *args, **kwargs):
        """Compute the metric. Must be overridden and must return a ``dict``."""
        raise NotImplementedError

    def _init_param_map(self, key_map, **kwargs):
        """Rebuild ``param_map`` from ``key_map`` and extra keyword mappings.

        :param key_map: dict mapping a parameter name of ``evaluate`` to the key
            used in the input dicts; ``None`` is treated as an empty mapping.
        :param kwargs: further ``evaluate_param=input_key`` mappings; they
            override entries of ``key_map`` with the same name.
        :raises TypeError: if any key or value is not a ``str``.
        """
        self.param_map = {}
        for key, value in (key_map or {}).items():
            if not isinstance(key, str):
                raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
            if not isinstance(value, str):
                raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
            self.param_map[key] = value
        for key, value in kwargs.items():
            if not isinstance(value, str):
                raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
            self.param_map[key] = value

    @staticmethod
    def _check_field(check_res, name):
        """Read field ``name`` from a check result given as a dict or as an object with attributes."""
        if isinstance(check_res, dict):
            return check_res.get(name)
        return getattr(check_res, name, None)

    def _to_input_names(self, check_res):
        """Rename ``evaluate`` parameter names in ``check_res`` to the input keys they are mapped to.

        The container type of ``check_res`` (dict or namedtuple) is preserved;
        any other type is returned unchanged.
        """

        def rename(names):
            return [self.param_map.get(name, name) for name in names]

        if isinstance(check_res, dict):
            return {field: rename(names) for field, names in check_res.items()}
        if hasattr(check_res, '_asdict'):
            renamed = {field: rename(names) for field, names in check_res._asdict().items()}
            return check_res._replace(**renamed)
        return check_res

    def __call__(self, output_dict, target_dict, force_check=False):
        """Map the inputs onto ``evaluate``'s parameters and run it.

        :param output_dict: dict of model outputs.
        :param target_dict: dict of targets.
        :param force_check: when True, re-run the missing/duplicated argument
            check even after a previous call has succeeded.
        :return: the dict returned by ``evaluate``.
        :raises TypeError: if ``evaluate`` is not callable or does not return a dict.
        :raises NameError: if ``param_map`` names a parameter ``evaluate`` does not have.
        :raises CheckError: if required arguments are missing or an argument is
            supplied by both dicts.
        """
        if not callable(self.evaluate):
            raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

        if not self._checked:
            # 1. check consistence between signature and param_map
            func_spect = inspect.getfullargspec(self.evaluate)
            # getfullargspec still reports the bound `self` of a method; it is not an input.
            func_args = [arg for arg in func_spect.args if arg != 'self']
            for func_param in self.param_map:
                if func_param not in func_args:
                    raise NameError(f"{func_param} not in {get_func_signature(self.evaluate)}.")
            # 2. only part of the param_map are passed, left are not
            for arg in func_args:
                if arg not in self.param_map:
                    self.param_map[arg] = arg  # This param does not need mapping.
            self._evaluate_args = func_args
            # input key -> evaluate parameter
            self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}

        # need to wrap inputs in dict.
        mapped_output_dict = {}
        mapped_target_dict = {}
        for func_arg in self._evaluate_args:
            input_arg = self.param_map[func_arg]
            if input_arg in output_dict:
                mapped_output_dict[func_arg] = output_dict[input_arg]
            if input_arg in target_dict:
                mapped_target_dict[func_arg] = target_dict[input_arg]

        # check duplicated, unused, missing
        if force_check or not self._checked:
            check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_target_dict])
            if self._check_field(check_res, 'missing') or self._check_field(check_res, 'duplicated'):
                # Report the names the caller used in the input dicts, not evaluate's parameter names.
                raise CheckError(check_res=self._to_input_names(check_res))

        refined_args = _build_args(self.evaluate, **mapped_output_dict, **mapped_target_dict)

        metrics = self.evaluate(**refined_args)

        if not isinstance(metrics, dict):
            raise TypeError(f"The return value of {get_func_signature(self.evaluate)} must be `dict`, "
                            f"got {type(metrics)}.")
        self._checked = True
        return metrics
class CheckError(Exception):
    """Raised when the arguments collected for a function are missing or duplicated.

    :param check_res: result of the argument check, either a dict or an object
        with attributes, exposing ``missing`` and ``duplicated``. It is kept on
        the exception as ``check_res``.
    """

    def __init__(self, check_res):
        missing = self._field(check_res, 'missing')
        duplicated = self._field(check_res, 'duplicated')
        err = ''
        if missing:
            err += f'Missing: {missing}\n'
        if duplicated:
            err += f'Duplicated: {duplicated}\n'
        # Hand the message to Exception so tracebacks and repr() show it too.
        super().__init__(err)
        self.err = err
        self.check_res = check_res

    @staticmethod
    def _field(check_res, name):
        """Read ``name`` from a dict-style or attribute-style check result."""
        if isinstance(check_res, dict):
            return check_res.get(name)
        return getattr(check_res, name, None)

    def __str__(self):
        return self.err
class Metric(MetricBase):
    """Placeholder metric wrapper.

    NOTE(review): the constructor accepts ``func``, ``key_map`` and extra
    keyword arguments but currently ignores all of them; only the base class
    is initialised.
    """

    def __init__(self, func, key_map, **kwargs):
        super(Metric, self).__init__()
def _prepare_metrics(metrics): | |||||
""" | |||||
Prepare list of Metric based on input | |||||
:param metrics: | |||||
:return: List[fastNLP.MetricBase] | |||||
""" | |||||
_metrics = [] | |||||
if metrics: | |||||
if isinstance(metrics, list): | |||||
for metric in metrics: | |||||
if isinstance(metric, type): | |||||
metric = metric() | |||||
if isinstance(metric, MetricBase): | |||||
_metrics.append(metric) | |||||
else: | |||||
raise TypeError(f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.") | |||||
elif isinstance(metrics, MetricBase): | |||||
_metrics = [metrics] | |||||
else: | |||||
raise TypeError("The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, got {}." | |||||
.format(type(metrics))) | |||||
return _metrics | |||||
class Evaluator(object): | class Evaluator(object): | ||||
def __init__(self): | def __init__(self): | ||||
@@ -17,7 +145,6 @@ class Evaluator(object): | |||||
""" | """ | ||||
raise NotImplementedError | raise NotImplementedError | ||||
class ClassifyEvaluator(Evaluator): | class ClassifyEvaluator(Evaluator): | ||||
def __init__(self): | def __init__(self): | ||||
super(ClassifyEvaluator, self).__init__() | super(ClassifyEvaluator, self).__init__() | ||||
@@ -2,32 +2,49 @@ import itertools | |||||
from collections import defaultdict | from collections import defaultdict | ||||
import torch | import torch | ||||
from torch import nn | |||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.sampler import RandomSampler | from fastNLP.core.sampler import RandomSampler | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.utils import _build_args | from fastNLP.core.utils import _build_args | ||||
from fastNLP.core.utils import get_func_signature | from fastNLP.core.utils import get_func_signature | ||||
from fastNLP.core.utils import _move_dict_value_to_device | |||||
from fastNLP.core.metrics import _prepare_metrics | |||||
class Tester(object): | class Tester(object): | ||||
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ | """An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ | ||||
def __init__(self, data, model, metrics, batch_size=16, use_cuda=False, verbose=0):
    """Validate the inputs and prepare the model for evaluation.

    :param data: the ``fastNLP.DataSet`` to evaluate on.
    :param model: the ``torch.nn.Module`` to evaluate.
    :param metrics: a metric or list of metrics, normalised by ``_prepare_metrics``.
    :param batch_size: stored as ``self.batch_size``.
    :param use_cuda: move the model to CUDA when CUDA is available.
    :param verbose: stored as ``self.verbose``.
    :raises TypeError: if ``data`` or ``model`` has the wrong type, or if the
        model has a ``predict`` attribute that is not callable.
    """
    super(Tester, self).__init__()

    if not isinstance(data, DataSet):
        raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.")
    if not isinstance(model, nn.Module):
        raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.")

    self.metrics = _prepare_metrics(metrics)

    self.data = data
    self.use_cuda = use_cuda
    self.batch_size = batch_size
    self.verbose = verbose

    # self._model has to be bound before the predict lookup below reads it.
    if torch.cuda.is_available() and self.use_cuda:
        self._model = model.cuda()
    else:
        self._model = model
    # Batches are moved to the device of the model's first parameter.
    self._model_device = next(self._model.parameters()).device

    # check predict
    if hasattr(self._model, 'predict'):
        self._predict_func = self._model.predict
        if not callable(self._predict_func):
            _model_name = model.__class__.__name__
            raise TypeError(f"`{_model_name}.predict` must be callable to be used "
                            f"for evaluation, not `{type(self._predict_func)}`.")
    else:
        self._predict_func = self._model
def test(self): | def test(self): | ||||
@@ -39,6 +56,7 @@ class Tester(object): | |||||
with torch.no_grad(): | with torch.no_grad(): | ||||
for batch_x, batch_y in data_iterator: | for batch_x, batch_y in data_iterator: | ||||
_move_dict_value_to_device(self._model_device, batch_x, batch_y) | |||||
prediction = self.data_forward(network, batch_x) | prediction = self.data_forward(network, batch_x) | ||||
assert isinstance(prediction, dict) | assert isinstance(prediction, dict) | ||||
for k, v in prediction.items(): | for k, v in prediction.items(): | ||||
@@ -49,10 +67,13 @@ class Tester(object): | |||||
output[k] = itertools.chain(*v) | output[k] = itertools.chain(*v) | ||||
for k, v in truths.items(): | for k, v in truths.items(): | ||||
truths[k] = itertools.chain(*v) | truths[k] = itertools.chain(*v) | ||||
# args = _build_args(self._evaluator, **output, **truths) | |||||
eval_results = self._evaluator(**args) | |||||
eval_results = {} | |||||
for metric in self.metrics: | |||||
eval_result = metric(output, truths) | |||||
metric_name = metric.__class__.__name__ | |||||
eval_results[metric_name] = eval_result | |||||
if self.verbose >= 0: | if self.verbose >= 0: | ||||
print("[tester] {}".format(self.print_eval_results(eval_results))) | |||||
print("[tester] \n{}".format(self.format_eval_results(eval_results))) | |||||
self.mode(network, is_test=False) | self.mode(network, is_test=False) | ||||
return eval_results | return eval_results | ||||
@@ -74,10 +95,15 @@ class Tester(object): | |||||
y = self._predict_func(**x) | y = self._predict_func(**x) | ||||
return y | return y | ||||
def print_eval_results(self, results): | |||||
def format_eval_results(self, results):
    """Override this method to support more print formats.

    Each metric is rendered as its name on one line, followed by a
    tab-indented line of comma-separated ``key=value`` pairs.

    :param results: dict, (str: dict) is (metric name: that metric's result dict)
    :return: str, the formatted text; empty when ``results`` is empty.
    """
    sections = []
    for metric_name, metric_result in results.items():
        # Format this metric's own values, not the outer results dict.
        values = ", ".join([str(key) + "=" + str(value) for key, value in metric_result.items()])
        sections.append(metric_name + '\n\t' + values + '\n')
    return ''.join(sections)
@@ -17,10 +17,15 @@ from fastNLP.core.sampler import SequentialSampler | |||||
from fastNLP.core.tester import Tester | from fastNLP.core.tester import Tester | ||||
from fastNLP.core.utils import _build_args | from fastNLP.core.utils import _build_args | ||||
from fastNLP.core.utils import _check_arg_dict_list | from fastNLP.core.utils import _check_arg_dict_list | ||||
from fastNLP.core.utils import _syn_model_data | |||||
from fastNLP.core.utils import _move_dict_value_to_device | |||||
from fastNLP.core.utils import get_func_signature | from fastNLP.core.utils import get_func_signature | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.losses import LossBase | |||||
from fastNLP.core.metrics import MetricBase | |||||
from fastNLP.core.losses import _prepare_losser | |||||
from fastNLP.core.metrics import _prepare_metrics | |||||
class Trainer(object): | class Trainer(object): | ||||
"""Main Training Loop | """Main Training Loop | ||||
@@ -32,6 +37,25 @@ class Trainer(object): | |||||
**kwargs): | **kwargs): | ||||
super(Trainer, self).__init__() | super(Trainer, self).__init__() | ||||
if not isinstance(train_data, DataSet): | |||||
raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.") | |||||
if not isinstance(model, nn.Module): | |||||
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.") | |||||
# check metrics and dev_data | |||||
if (not metrics) and dev_data is not None: | |||||
raise ValueError("No metric for dev_data evaluation.") | |||||
if metrics and (dev_data is None): | |||||
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ") | |||||
# prepare evaluate | |||||
metrics = _prepare_metrics(metrics) | |||||
# prepare loss | |||||
losser = _prepare_losser(losser) | |||||
if need_check_code: | |||||
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data) | |||||
self.train_data = train_data | self.train_data = train_data | ||||
self.dev_data = dev_data # If None, No validation. | self.dev_data = dev_data # If None, No validation. | ||||
self.model = model | self.model = model | ||||
@@ -45,10 +69,7 @@ class Trainer(object): | |||||
self.validate_every = int(validate_every) | self.validate_every = int(validate_every) | ||||
self._best_accuracy = 0 | self._best_accuracy = 0 | ||||
# TODO check loss与metrics的类型 | |||||
self._model_device = model.parameters().__next__().device | |||||
# TODO: self._best_accuracy cannot represent the case where several metrics are in use
@@ -72,38 +93,6 @@ class Trainer(object): | |||||
# print(self.__dict__) | # print(self.__dict__) | ||||
def _check_params(self, train_data, model, losser, metrics=[], n_epochs=3, batch_size=32, print_every=-1, | |||||
validate_every=-1, dev_data=None, use_cuda=False, save_path="./save", | |||||
optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True, | |||||
**kwargs): | |||||
if not isinstance(train_data, DataSet): | |||||
raise TypeError("The type of train_data must be fastNLP.DataSet, got {}.".\ | |||||
format(type(train_data))) | |||||
if not isinstance(model, nn.Module): | |||||
raise TypeError("The type of model must be torch.nn.Module, got {}.".\ | |||||
format(type(model))) | |||||
if losser is not None: | |||||
# TODO change | |||||
if not isinstance(losser, None): | |||||
raise TypeError("The type of losser must be xxx, got {}.".\ | |||||
format(type(losser))) | |||||
# check metrics and dev_data | |||||
if (not metrics) and dev_data is not None: | |||||
raise ValueError("No metric for dev_data evaluation.") | |||||
if metrics and (dev_data is None): | |||||
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ") | |||||
# check loss | |||||
if isinstance(losser, type): | |||||
self.losser = losser() | |||||
if not isinstance(self.losser, None): | |||||
raise TypeError(f'The type of losser must be `{}`, got {type(self.losser)}.') | |||||
if need_check_code: | |||||
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data) | |||||
def train(self): | def train(self): | ||||
"""Start Training. | """Start Training. | ||||
@@ -153,8 +142,9 @@ class Trainer(object): | |||||
- epoch: int, | - epoch: int, | ||||
""" | """ | ||||
for batch_x, batch_y in data_iterator: | for batch_x, batch_y in data_iterator: | ||||
# TODO: this may break if the user changes the device of the prediction inside the model
_move_dict_value_to_device(self._model_device, batch_x, batch_y) | |||||
prediction = self.data_forward(model, batch_x) | prediction = self.data_forward(model, batch_x) | ||||
loss = self.get_loss(prediction, batch_y) | loss = self.get_loss(prediction, batch_y) | ||||
self.grad_backward(loss) | self.grad_backward(loss) | ||||
self.update() | self.update() | ||||
@@ -205,7 +195,6 @@ class Trainer(object): | |||||
x = _build_args(network.forward, **x) | x = _build_args(network.forward, **x) | ||||
y = network(**x) | y = network(**x) | ||||
if not isinstance(y, dict): | if not isinstance(y, dict): | ||||
raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.") | raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.") | ||||
return y | return y | ||||
@@ -299,7 +288,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||||
# check loss output | # check loss output | ||||
if batch_count == 0: | if batch_count == 0: | ||||
if not isinstance(loss, torch.Tensor): | if not isinstance(loss, torch.Tensor): | ||||
raise ValueError("The return value of {} should be torch.Tensor, but got {}.". | |||||
raise ValueError("The return value of {} should be `torch.Tensor`, but got `{}`.". | |||||
format(type(losser), type(loss))) | format(type(losser), type(loss))) | ||||
if len(loss.size())!=0: | if len(loss.size())!=0: | ||||
raise ValueError("The size of return value of {} is {}, should be torch.size([])".format( | raise ValueError("The size of return value of {} is {}, should be torch.size([])".format( | ||||
@@ -314,7 +303,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||||
outputs, truths = defaultdict(list), defaultdict(list) | outputs, truths = defaultdict(list), defaultdict(list) | ||||
dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) | dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) | ||||
# TODO: change this to use Tester
tester = Tester(data=dataset, model=model, metrics=metrics, batch_size=batch_size, ) | |||||
with torch.no_grad(): | with torch.no_grad(): | ||||
for batch_count, (batch_x, batch_y) in enumerate(dev_batch): | for batch_count, (batch_x, batch_y) in enumerate(dev_batch): | ||||
@@ -3,11 +3,9 @@ import inspect | |||||
import os | import os | ||||
from collections import Counter | from collections import Counter | ||||
from collections import namedtuple | from collections import namedtuple | ||||
from collections import defaultdict | |||||
import torch | import torch | ||||
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False) | |||||
def save_pickle(obj, pickle_path, file_name): | def save_pickle(obj, pickle_path, file_name): | ||||
"""Save an object into a pickle file. | """Save an object into a pickle file. | ||||
@@ -89,11 +87,15 @@ def _check_arg_dict_list(func, args): | |||||
input_args = set(input_arg_count.keys()) | input_args = set(input_arg_count.keys()) | ||||
missing = list(require_args - input_args) | missing = list(require_args - input_args) | ||||
unused = list(input_args - all_args) | unused = list(input_args - all_args) | ||||
return CheckRes(missing=missing, | |||||
unused=unused, | |||||
duplicated=duplicated, | |||||
required=list(require_args), | |||||
all_needed=list(all_args)) | |||||
check_res = {} | |||||
check_res['missing'] = missing | |||||
check_res['unused'] = unused | |||||
check_res['duplicated'] = duplicated | |||||
check_res['required'] = list(require_args) | |||||
check_res['all_needed'] = list(all_args) | |||||
return check_res | |||||
def get_func_signature(func): | def get_func_signature(func): | ||||
""" | """ | ||||
@@ -150,31 +152,22 @@ def _syn_model_data(model, *args): | |||||
else: | else: | ||||
raise TypeError("Only support `dict` type right now.") | raise TypeError("Only support `dict` type right now.") | ||||
def _prepare_metrics(metrics): | |||||
def _move_dict_value_to_device(device, *args): | |||||
""" | """ | ||||
Prepare list of Metric based on input | |||||
:param metrics: | |||||
move data to model's device, element in *args should be dict. This is a inplace change. | |||||
:param device: torch.device | |||||
:param args: | |||||
:return: | :return: | ||||
""" | """ | ||||
_metrics = [] | |||||
if metrics: | |||||
if isinstance(metrics, list): | |||||
for metric in metrics: | |||||
if isinstance(metric, type): | |||||
metric = metric() | |||||
if isinstance(metric, None): | |||||
_metrics.append(metric) | |||||
else: | |||||
raise TypeError("The type of metric in metrics must be xxxx, not {}.".format( | |||||
type(), type(metric) | |||||
)) | |||||
elif isinstance(metrics, None): | |||||
_metrics = [metrics] | |||||
else: | |||||
raise TypeError("The type of metrics should be `list[xxx]` or `xxx`, got {}.".format( | |||||
type(metrics) | |||||
)) | |||||
if not isinstance(device, torch.device): | |||||
raise TypeError(f"device must be `torch.device`, got `{type(device)}`") | |||||
return _metrics | |||||
for arg in args: | |||||
if isinstance(arg, dict): | |||||
for key, value in arg.items(): | |||||
if isinstance(value, torch.Tensor): | |||||
arg[key] = value.to(device) | |||||
else: | |||||
raise TypeError("Only support `dict` type right now.") | |||||