@@ -1,4 +1,5 @@
import torch
import numpy as np
class Batch(object):
@@ -45,7 +46,7 @@ class Batch(object):
            if field.is_target or field.is_input:
                batch = field.get(indices)
                if not self.as_numpy:
                    batch = torch.from_numpy(batch)
                    batch = to_tensor(batch, field.dtype)
                if field.is_target:
                    batch_y[field_name] = batch
                if field.is_input:
@@ -54,3 +55,10 @@ class Batch(object):
        self.curidx = endidx
        return batch_x, batch_y
def to_tensor(batch, dtype):
    if dtype in (np.int8, np.int16, np.int32, np.int64):
        batch = torch.LongTensor(batch)
    if dtype in (np.float32, np.float64):
        batch = torch.FloatTensor(batch)
    return batch
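# Example (a minimal sketch of the intended behaviour of the new to_tensor helper;
# the sample arrays below are illustrative and not part of the patch):
# int-like numpy batches become LongTensor, float-like batches become FloatTensor.
example_int_batch = np.array([[1, 2], [3, 4]], dtype=np.int64)
example_float_batch = np.array([[0.1, 0.2]], dtype=np.float32)
assert to_tensor(example_int_batch, example_int_batch.dtype).dtype == torch.int64
assert to_tensor(example_float_batch, example_float_batch.dtype).dtype == torch.float32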
@@ -39,7 +39,7 @@ class FieldArray(object):
    @staticmethod
    def _map_to_np_type(basic_type):
        type_mapping = {int: np.int64, float: np.double, str: np.str}
        type_mapping = {int: np.int64, float: np.float64, str: np.str}
        return type_mapping[basic_type]
    def __repr__(self):
@@ -126,15 +126,30 @@ class NLLLoss(LossBase):
class LossInForward(LossBase):
    def __init__(self, loss_key='loss'):
        super().__init__()
        if not isinstance(loss_key, str):
            raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.")
        self.loss_key = loss_key
    def get_loss(self, **kwargs):
        if self.loss_key not in kwargs:
            pass
            check_res = CheckRes(missing=[self.loss_key],
                                 unused=[],
                                 duplicated=[],
                                 required=[],
                                 all_needed=[],
                                 varargs=[])
            raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss))
    def __call__(self, output_dict, predict_dict):
    def __call__(self, output_dict, predict_dict, force_check=False):
        return self.get_loss(**output_dict)
        loss = self.get_loss(**output_dict)
        if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
            if not isinstance(loss, torch.Tensor):
                raise TypeError(f"loss ERROR: loss expects a torch.Tensor but got {type(loss)}")
            raise RuntimeError(f"loss ERROR: the size of loss is expected to be torch.Size([]) but got {loss.size()}")
        return loss
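# Usage sketch (illustrative; it assumes get_loss also returns kwargs[self.loss_key]
# when the key is present, and the field names below are made up):
losser = LossInForward(loss_key='loss')
output_dict = {'logits': torch.randn(4, 3), 'loss': torch.tensor(0.7)}
loss = losser(output_dict, predict_dict={})
# 'loss' is looked up in the model's forward output; a 0-dim torch.Tensor passes the
# size check in __call__, anything else raises TypeError / RuntimeError.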
def _prepare_losser(losser):
@@ -10,7 +10,7 @@ from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _check_function_or_method
class MetricBase(object):
    def __init__(self):
@@ -20,19 +20,32 @@ class MetricBase(object):
    def evaluate(self, *args, **kwargs):
        raise NotImplementedError
    def _init_param_map(self, key_map, **kwargs):
        self.param_map = {}
        value_counter = defaultdict(0)
        for key, value in key_map.items():
            if isinstance(key, str):
                raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
            if isinstance(value, str):
                raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
            self.param_map[key] = value
    def _init_param_map(self, key_map=None, **kwargs):
        value_counter = defaultdict(set)
        if key_map is not None:
            if not isinstance(key_map, dict):
                raise TypeError("key_map must be `dict`, got {}.".format(type(key_map)))
            for key, value in key_map.items():
                if value is None:
                    self.param_map[key] = key
                    continue
                if not isinstance(key, str):
                    raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
                if not isinstance(value, str):
                    raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
                self.param_map[key] = value
                value_counter[value].add(key)
        for key, value in kwargs.items():
            if value is None:
                self.param_map[key] = key
                continue
            if not isinstance(value, str):
                raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
            self.param_map[key] = value
            value_counter[value].add(key)
        for value, key_set in value_counter.items():
            if len(key_set) > 1:
                raise ValueError(f"Several params:{key_set} are provided with one output {value}.")
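        # Worked example (illustrative names, not part of the patch): for a metric whose
        # evaluate() takes `predictions` and `targets`,
        #     self._init_param_map(key_map={'predictions': 'pred'}, targets='label')
        # yields param_map == {'predictions': 'pred', 'targets': 'label'}, i.e. evaluate's
        # `predictions` argument is filled from the `pred` key of the model output and
        # `targets` from the `label` field of the dataset. Mapping two evaluate arguments
        # onto the same key, e.g. _init_param_map(predictions='out', targets='out'), now
        # raises ValueError because value_counter collects {'out': {'predictions', 'targets'}}.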
    def __call__(self, output_dict, target_dict, check=False):
        """
@@ -45,8 +58,6 @@ class MetricBase(object):
            raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")
        if not self._checked:
            # 0. check param_map does not have same value
            # 1. check consistence between signature and param_map
            func_spect = inspect.getfullargspec(self.evaluate)
            func_args = func_spect.args
@@ -58,26 +69,32 @@ class MetricBase(object):
                if arg not in self.param_map:
                    self.param_map[arg] = arg  # This param does not need mapping.
            self._evaluate_args = func_args
            self._reverse_param_map = {value: key for key, value in self.param_map.items()}
        # need to wrap inputs in dict.
        mapped_output_dict = {}
        mapped_target_dict = {}
        for func_arg in self._evaluate_args:
            input_arg = self.param_map[func_arg]
            if input_arg in self._reverse_param_map:
                mapped_arg = func_arg
            else:
                mapped_arg = input_arg
            if input_arg in output_dict:
                mapped_output_dict[func_arg] = output_dict[input_arg]
                mapped_output_dict[mapped_arg] = output_dict[input_arg]
            if input_arg in target_dict:
                mapped_target_dict[func_arg] = target_dict[input_arg]
                mapped_target_dict[mapped_arg] = target_dict[input_arg]
        # check duplicated, unused, missing
        if check or not self._checked:
            check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_target_dict])
            self._reverse_param_map = {value: key for key, value in check_res.items()}
            for key, value in check_res.items():
                new_value = list(value)
                for idx, func_param in enumerate(value):
                    if func_param in self._reverse_param_map:
                        new_value[idx] = self._reverse_param_map[func_param]
                        new_value[idx] = self._reverse_param_map[func_param] + f'(assign to {func_param})'
                    else:
                        new_value[idx] = func_param
            if check_res.missing or check_res.duplicated or check_res.varargs:
                raise CheckError(check_res=check_res,
                                 func_signature=get_func_signature(self.evaluate))
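        # Worked example of the mapping above (illustrative): with
        # param_map == {'predictions': 'pred'} and output_dict == {'pred': tensor},
        # _reverse_param_map == {'pred': 'predictions'}, so the loop stores
        # mapped_output_dict['predictions'] = output_dict['pred'] and evaluate() can be
        # called with its own argument names; any missing/unused/duplicated keys found by
        # _check_arg_dict_list are reported through CheckError using the original names.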
@@ -93,11 +110,55 @@ class MetricBase(object):
        return metrics
class Metric(MetricBase):
class FuncMetric(MetricBase):
    def __init__(self, func, key_map, **kwargs):
        super().__init__()
        _check_function_or_method(func=func)
        self._init_param_map(key_map=key_map, **kwargs)
        self.evaluate = func
class AccuracyMetric(MetricBase):
    def __init__(self, predictions=None, targets=None, masks=None, seq_lens=None):
        super().__init__()
        self._init_param_map(predictions=predictions, targets=targets,
                             masks=masks, seq_lens=seq_lens)
    def evaluate(self, predictions, targets, masks=None, seq_lens=None):
""" | |||
:param predictions: List of (torch.Tensor, or numpy.ndarray). Element's shape can be: | |||
torch.Size([]), torch.Size([n_classes,]), torch.Size([max_len,]), torch.Size([max_len, n_classes]) | |||
:param targets: List of (torch.Tensor, or numpy.ndarray). Element's can be: | |||
torch.Size([]), torch.Size([]), torch.Size([max_len,]), torch.Size([max_len, ]) | |||
:param masks: List of (torch.Tensor, or numpy.ndarray). Element's can be: | |||
None, None, torch.Size([max_len,], torch.Size([max_len, ]) | |||
:param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's can be: | |||
None, None, torch.Size([1], torch.Size([1]) | |||
:return: dict({'acc': float}) | |||
""" | |||
pass | |||
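        # The body is still a stub in this diff. A sketch of the kind of computation the
        # docstring describes, for the [max_len, n_classes] prediction / [max_len] target
        # case with an optional mask (illustrative only, not the final implementation):
        #     correct, total = 0, 0
        #     for pred, target, mask in zip(predictions, targets, masks or [None] * len(predictions)):
        #         pred = torch.argmax(torch.as_tensor(pred), dim=-1)
        #         target = torch.as_tensor(target)
        #         hit = (pred == target).float()
        #         if mask is None:
        #             correct += hit.sum().item()
        #             total += target.numel()
        #         else:
        #             mask = torch.as_tensor(mask).float()
        #             correct += (hit * mask).sum().item()
        #             total += mask.sum().item()
        #     return {'acc': correct / max(total, 1)}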
    def _check_evaluate_param(self, predictions, targets, masks=None, seq_lens=None):
        # check the validity of self.evaluate params
        prediction = predictions[0]
        target = targets[0]
        if len(np.shape(prediction)) == len(np.shape(target)):
            pass
        if masks is not None:
            mask = masks[0]
        if seq_lens is not None:
            seq_len = seq_lens[0]
def _prepare_metrics(metrics):
""" | |||
@@ -7,11 +7,11 @@ from torch import nn | |||
from fastNLP.core.batch import Batch | |||
from fastNLP.core.sampler import SequentialSampler | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.utils import CheckError | |||
from fastNLP.core.utils import _build_args | |||
from fastNLP.core.utils import get_func_signature | |||
from fastNLP.core.utils import _move_dict_value_to_device | |||
from fastNLP.core.metrics import _prepare_metrics | |||
from fastNLP.core.utils import CheckError | |||
from fastNLP.core.utils import _check_loss_evaluate | |||
class Tester(object): | |||
@@ -57,7 +57,7 @@ class Tester(object):
        with torch.no_grad():
            for batch_x, batch_y in data_iterator:
                _move_dict_value_to_device(self._model_device, batch_x, batch_y)
                _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                prediction = self._data_forward(self._predict_func, batch_x)
                assert isinstance(prediction, dict)
                for k, v in prediction.items():
@@ -77,7 +77,7 @@ class Tester(object):
        except CheckError as e:
            prev_func_signature = get_func_signature(self._predict_func)
            _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
                                 check_res=e.check_res, output=output, batch_y=truths)
                                 check_res=e.check_res, output=output, batch_y=truths, check_level=0)
        if self.verbose >= 0:
@@ -1,6 +1,5 @@
import os
import time
import warnings
from datetime import datetime
from datetime import timedelta
@@ -9,24 +8,19 @@ from tensorboardX import SummaryWriter
from torch import nn
from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.optimizer import Adam
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.tester import Tester
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature
from fastNLP.core.dataset import DataSet
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _check_loss_evaluate
from fastNLP.core.utils import _check_forward_error
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature
class Trainer(object):
    """Main Training Loop
@@ -52,6 +46,9 @@ class Trainer(object):
        if metrics and (dev_data is None):
            raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")
        # check save_path
        if not (save_path is None or isinstance(save_path, str)):
            raise ValueError("save_path can only be None or `str`.")
        # prepare evaluate
        metrics = _prepare_metrics(metrics)
@@ -156,7 +153,7 @@ class Trainer(object):
        """
        for batch_x, batch_y in data_iterator:
            # TODO: this may cause problems if the user changes the device of the prediction inside the model
            _move_dict_value_to_device(self._model_device, batch_x, batch_y)
            _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
            prediction = self._data_forward(model, batch_x)
            loss = self._compute_loss(prediction, batch_y)
            self._grad_backward(loss)
@@ -232,11 +229,12 @@ class Trainer(object):
        return self.losser(predict, truth)
    def _save_model(self, model, model_name, only_param=False):
        model_name = os.path.join(self.save_path, model_name)
        if only_param:
            torch.save(model.state_dict(), model_name)
        else:
            torch.save(model, model_name)
        if self.save_path is not None:
            model_name = os.path.join(self.save_path, model_name)
            if only_param:
                torch.save(model.state_dict(), model_name)
            else:
                torch.save(model, model_name)
    def _better_eval_result(self, metrics):
        """Check if the current epoch yields better validation results.
@@ -297,7 +295,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
    batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
    for batch_count, (batch_x, batch_y) in enumerate(batch):
        _move_dict_value_to_device(model_devcie, batch_x, batch_y)
        _move_dict_value_to_device(batch_x, batch_y, device=model_devcie)
        # forward check
        if batch_count == 0:
            _check_forward_error(forward_func=model.forward, check_level=check_level,
@@ -335,6 +333,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
    if dev_data is not None:
        tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
                        batch_size=batch_size, verbose=-1)
        tester.test()
        evaluate_results = tester.test()
        # TODO: check whether the returned values are reasonable
@@ -122,13 +122,13 @@ def _check_arg_dict_list(func, args):
    input_args = set(input_arg_count.keys())
    missing = list(require_args - input_args)
    unused = list(input_args - all_args)
    varargs = [] if not spect.varargs else [arg for arg in spect.varargs]
    return CheckRes(missing=missing,
                    unused=unused,
                    duplicated=duplicated,
                    required=list(require_args),
                    all_needed=list(all_args),
                    varargs=[arg for arg in spect.varargs])
                    varargs=varargs)
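# Illustration of the CheckRes produced above (names are made up): for
#     def evaluate(predictions, targets, masks=None): ...
# checked against [{'predictions': p}, {'label': t}], the result is roughly
#     CheckRes(missing=['targets'], unused=['label'], duplicated=[],
#              required=['predictions', 'targets'],
#              all_needed=['predictions', 'targets', 'masks'], varargs=[])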
def get_func_signature(func):
    """
@@ -165,6 +165,7 @@ def get_func_signature(func):
        signature_str = func.__name__ + signature_str
        return signature_str
def _is_function_or_method(func):
    """
@@ -179,26 +180,8 @@ def _check_function_or_method(func):
    if not _is_function_or_method(func):
        raise TypeError(f"{type(func)} is not a method or function.")
def _syn_model_data(model, *args):
    """
    move data to model's device, element in *args should be dict. This is a inplace change.
    :param model:
    :param args:
    :return:
    """
    if len(model.state_dict()) == 0:
        raise ValueError("model has no parameter.")
    device = model.parameters().__next__().device
    for arg in args:
        if isinstance(arg, dict):
            for key, value in arg.items():
                if isinstance(value, torch.Tensor):
                    arg[key] = value.to(device)
        else:
            raise TypeError("Only support `dict` type right now.")
def _move_dict_value_to_device(device, *args):
def _move_dict_value_to_device(*args, device: torch.device):
    """
    Move data to the model's device. Each element in *args should be a dict. This is an in-place change.
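    # Usage sketch for the new keyword-only signature (device must now be passed by name);
    # batch_x / batch_y are the usual dicts of tensors produced by Batch:
    #     _move_dict_value_to_device(batch_x, batch_y, device=torch.device('cuda:0'))
    # every torch.Tensor value in the passed dicts is replaced in place by a copy on `device`.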
@@ -240,6 +223,7 @@ class CheckError(Exception):
        self.check_res = check_res
        self.func_signature = func_signature
IGNORE_CHECK_LEVEL = 0
WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2
@@ -252,8 +236,8 @@ def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:
        errs.append(f"\tvarargs: {check_res.varargs}(passing positional arguments is not supported, "
                    f"please remove them.)")
    if check_res.missing:
        errs.append(f"\tmissing param: {check_res.missing}, only provided with {list(output.keys())}"
                    f"(from {prev_func_signature}) and {list(batch_y.keys())}(from targets in Dataset).")
        errs.append(f"\tmissing param: `{check_res.missing}`, provided with `{list(output.keys())}`"
                    f"(from output of `{prev_func_signature}`) and `{list(batch_y.keys())}`(from targets in Dataset).")
    if check_res.duplicated:
        errs.append(f"\tduplicated param: {check_res.duplicated}, delete {check_res.duplicated} in the output of "
                    f"`{prev_func_signature}` or do not set {check_res.duplicated} as targets. ")
@@ -281,7 +265,7 @@ def _check_forward_error(forward_func, batch_x, check_level):
    if check_res.varargs:
        errs.append(f"\tvarargs: {check_res.varargs}(passing positional arguments is not supported, please remove them)")
    if check_res.missing:
        errs.append(f"\tmissing param: {check_res.missing}, only provided with {list(batch_x.keys())}.")
        errs.append(f"\tmissing param: {check_res.missing}, provided with {list(batch_x.keys())}.")
    if check_res.unused:
        _unused = [f"\tunused param: {check_res.unused}"]
        if check_level == STRICT_CHECK_LEVEL: