CheckError add function

tags/v0.2.0^2
yh 6 years ago
parent commit 0d4720b1d9
4 changed files with 57 additions and 105 deletions
  1. +7  -21  fastNLP/core/metrics.py
  2. +18 -12  fastNLP/core/tester.py
  3. +22 -65  fastNLP/core/trainer.py
  4. +10 -7   fastNLP/core/utils.py

+7 -21  fastNLP/core/metrics.py

@@ -8,6 +8,8 @@ import torch
 from fastNLP.core.utils import get_func_signature
 from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import _build_args
+from fastNLP.core.utils import CheckError



 class MetricBase(object):
     def __init__(self):
@@ -29,7 +31,7 @@ class MetricBase(object):
             if isinstance(value, str):
                 raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
             self.param_map[key] = value
     def __call__(self, output_dict, target_dict, force_check=False):
         """
         :param output_dict:
@@ -67,7 +69,7 @@ class MetricBase(object):
             check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_output_dict])
             self._reverse_param_map = {value:key for key, value in check_res.items()}
             for key, value in check_res.items():
-                new_value = value.copy()
+                new_value = list(value)
                 for idx, func_param in enumerate(value):
                     if func_param in self._reverse_param_map:
                         new_value[idx] = self._reverse_param_map[func_param]
@@ -85,28 +87,12 @@ class MetricBase(object):
         return metrics


-class CheckError(Exception):
-    def __init__(self, check_res):
-        err = ''
-        if check_res.missing:
-            err += f'Missing: {check_res.missing}\n'
-        if check_res.duplicated:
-            err += f'Duplicated: {check_res.duplicated}\n'
-        self.check_res = check_res
-
-    def __str__(self):
-        pass


 class Metric(MetricBase):
     def __init__(self, func, key_map, **kwargs):
         super().__init__()
         pass



 def _prepare_metrics(metrics):
     """


@@ -127,8 +113,8 @@ def _prepare_metrics(metrics):
     elif isinstance(metrics, MetricBase):
         _metrics = [metrics]
     else:
-        raise TypeError("The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, got {}."
-                        .format(type(metrics)))
+        raise TypeError(f"The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, "
+                        f"got {type(metrics)}.")
     return _metrics

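One small but real fix in this file: the change from value.copy() to list(value) in MetricBase.__call__ guards against the per-key value being a tuple rather than a list, since tuples have no .copy() method while list(value) yields a mutable copy of any iterable (the tuple case is an inference from the change, not stated in the commit). A quick interpreter check:

    >>> t = ('pred', 'target')
    >>> t.copy()
    Traceback (most recent call last):
      ...
    AttributeError: 'tuple' object has no attribute 'copy'
    >>> list(t)
    ['pred', 'target']
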
+18 -12  fastNLP/core/tester.py

@@ -5,12 +5,13 @@ import torch
 from torch import nn


 from fastNLP.core.batch import Batch
-from fastNLP.core.sampler import RandomSampler
+from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import get_func_signature
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.metrics import _prepare_metrics
+from fastNLP.core.utils import CheckError


 class Tester(object):
     """A collection of model inference and performance evaluation, used on the validation/dev set and test set."""
@@ -33,7 +34,7 @@ class Tester(object):
raise TypeError(f"`{_model_name}.predict` must be callable to be used " raise TypeError(f"`{_model_name}.predict` must be callable to be used "
f"for evaluation, not `{type(self._predict_func)}`.") f"for evaluation, not `{type(self._predict_func)}`.")
else: else:
self._predict_func = self._model
self._predict_func = self._model.forward


         self.data = data
         if torch.cuda.is_available() and self.use_cuda:
@@ -50,14 +51,14 @@ class Tester(object):
     def test(self):
         # turn on the testing mode; clean up the history
         network = self._model
-        self.mode(network, is_test=True)
+        self._mode(network, is_test=True)
         output, truths = defaultdict(list), defaultdict(list)
-        data_iterator = Batch(self.data, self.batch_size, sampler=RandomSampler(), as_numpy=False)
+        data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False)


         with torch.no_grad():
             for batch_x, batch_y in data_iterator:
                 _move_dict_value_to_device(self._model_device, batch_x, batch_y)
-                prediction = self.data_forward(network, batch_x)
+                prediction = self._data_forward(self._predict_func, batch_x)
                 assert isinstance(prediction, dict)
                 for k, v in prediction.items():
                     output[k].append(v)
@@ -68,16 +69,21 @@ class Tester(object):
         for k, v in truths.items():
             truths[k] = itertools.chain(*v)
         eval_results = {}
+        try:
             for metric in self.metrics:
                 eval_result = metric(output, truths)
                 metric_name = metric.__class__.__name__
                 eval_results[metric_name] = eval_result
+        except CheckError as e:
+            pass


         if self.verbose >= 0:
-            print("[tester] \n{}".format(self.format_eval_results(eval_results)))
-        self.mode(network, is_test=False)
+            print("[tester] \n{}".format(self._format_eval_results(eval_results)))
+        self._mode(network, is_test=False)
         return eval_results


-    def mode(self, model, is_test=False):
+    def _mode(self, model, is_test=False):
         """Train mode or Test mode. This is for PyTorch currently.


         :param model: a PyTorch model
@@ -89,13 +95,13 @@ class Tester(object):
         else:
             model.train()


-    def data_forward(self, network, x):
+    def _data_forward(self, func, x):
         """A forward pass of the model. """
-        x = _build_args(network.forward, **x)
-        y = self._predict_func(**x)
+        x = _build_args(func, **x)
+        y = func(**x)
         return y


-    def format_eval_results(self, results):
+    def _format_eval_results(self, results):
         """Override this method to support more print formats.


         :param results: dict, (str: float) is (metrics name: value)


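Two details in this file deserve a note. Evaluation now iterates with SequentialSampler instead of RandomSampler: test-time inference visits every example exactly once, so a stable order costs nothing and keeps runs comparable. And with mode/data_forward/format_eval_results renamed to underscore-prefixed helpers, the public surface is just the constructor and test(). A minimal usage sketch, using only the keyword names visible in this commit (model and dataset construction elided):

    # Hedged usage sketch; the argument names are taken from the Tester(...)
    # call in trainer.py below, nothing else is assumed.
    tester = Tester(data=dev_data,      # a fastNLP DataSet
                    model=model,        # an nn.Module; .predict is used if present
                    metrics=metrics,    # MetricBase or list[MetricBase]
                    batch_size=16,
                    verbose=0)          # verbose >= 0 prints the formatted results
    eval_results = tester.test()        # {metric class name: metric output}
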
+22 -65  fastNLP/core/trainer.py

@@ -25,7 +25,7 @@ from fastNLP.core.losses import LossBase
 from fastNLP.core.metrics import MetricBase
 from fastNLP.core.losses import _prepare_losser
 from fastNLP.core.metrics import _prepare_metrics
+from fastNLP.core.utils import CheckError


 class Trainer(object):
     """Main Training Loop
@@ -211,13 +211,11 @@ class Trainer(object):
     def get_loss(self, predict, truth):
         """Compute loss given prediction and ground truth.


-        :param predict: prediction label vector
-        :param truth: ground truth label vector
+        :param predict: prediction dict, produced by model.forward
+        :param truth: ground truth dict, produced by batch_y
         :return: a scalar
         """
-        assert isinstance(predict, dict) and isinstance(truth, dict)
-        args = _build_args(self.loss_func, **predict, **truth)
-        return self.loss_func(**args)
+        return self.losser(predict, truth)


     def save_model(self, model, model_name, only_param=False):
         model_name = os.path.join(self.save_path, model_name)
@@ -260,11 +258,11 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
                 dev_data=None,
                 check_level=WARNING_CHECK_LEVEL):
     # check the get_loss method
-    model_name = model.__class__.__name__
+    model_devcie = model.parameters().__next__().device


     batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
     for batch_count, (batch_x, batch_y) in enumerate(batch):
-        _syn_model_data(model, batch_x, batch_y)
+        _move_dict_value_to_device(model_devcie, batch_x, batch_y)
         # forward check
         if batch_count==0:
             _check_forward_error(model_func=model.forward, check_level=check_level,
@@ -277,68 +275,29 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`.") raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`.")


         # loss check
-        if isinstance(losser, type):  # in this case the user passed an uninitialized loss, e.g. losser.CE
-            # need to make sure output and batch_y are unambiguous?
-            # (1) output and batch_y both have length 1
-            # (2) the keys of output and batch_y exactly match what losser accepts
-            pass
-
-        loss = losser(output, batch_y)
+        try:
+            loss = losser(output, batch_y)
+        except CheckError as e:
+            _check_loss_evaluate(prev_func=model.forward, func=e.func_signature,
+                                 check_res=e.check_res, output=output, batch_y=batch_y,
+                                 check_level=check_level)
         # check loss output
         if batch_count == 0:
             if not isinstance(loss, torch.Tensor):
-                raise ValueError("The return value of {} should be `torch.Tensor`, but got `{}`.".
-                                 format(type(losser), type(loss)))
+                raise TypeError(f"The return value of {get_func_signature(losser.__call__)} should be `torch.Tensor`, "
+                                f"but got `{type(loss)}`.")
             if len(loss.size())!=0:
-                raise ValueError("The size of return value of {} is {}, should be torch.size([])".format(
-                    type(losser), loss.size()))
+                raise ValueError(f"The size of return value of {get_func_signature(losser.__call__)} is {loss.size()}, "
+                                 f"should be torch.size([])")
         loss.backward()
         model.zero_grad()
         if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
             break


     if dev_data is not None:
-        outputs, truths = defaultdict(list), defaultdict(list)
-        dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
-        # TODO switch this to use the tester
-        tester = Tester(data=dataset, model=model, metrics=metrics, batch_size=batch_size, )
-
-        with torch.no_grad():
-            for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
-                _syn_model_data(model, batch_x, batch_y)
-
-                if hasattr(model, 'predict'):
-                    if not callable(model.predict):
-                        raise TypeError(f"{get_func_signature(model.predict)} must be callable to be used "
-                                        f"for evaluation.")
-                    refined_batch_x = _build_args(model.predict, **batch_x)
-                    prev_func = model.predict
-                    output = prev_func(**refined_batch_x)
-                else:
-                    refined_batch_x = _build_args(model.forward, **batch_x)
-                    prev_func = model.forward
-                    output = prev_func(**refined_batch_x)
-                func_signature = get_func_signature(prev_func)
-                if not isinstance(output, dict):
-                    raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`")
-                for k, v in output.items():
-                    outputs[k].append(v)
-                for k, v in batch_y.items():
-                    truths[k].append(v)
-                if batch_count+1>DEFAULT_CHECK_NUM_BATCH:
-                    break
-        for k, v in outputs.items():
-            outputs[k] = tuple(itertools.chain(*v))
-        for k, v in truths.items():
-            truths[k] = tuple(itertools.chain(*v))
-        # TODO update this for the new metrics API; errors raised by the metrics also need
-        # to be caught here, since they are needed to guide the user's debugging
+        tester = Tester(data=dataset[:batch_size*DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
+                        batch_size=batch_size, verbose=-1)
+        tester.test()


 def _check_forward_error(model_func, check_level, batch_x):
@@ -346,11 +305,11 @@ def _check_forward_error(model_func, check_level, batch_x):
     _missing = ''
     _unused = ''
     func_signature = get_func_signature(model_func)
-    if len(check_res.missing)!=0:
+    if len(check_res['missing'])!=0:
         _missing = "Function {} misses {}, only provided with {}, " \
                    ".\n".format(func_signature, check_res.missing,
                                 list(batch_x.keys()))
-    if len(check_res.unused)!=0:
+    if len(check_res['unused'])!=0:
         if len(check_res.unused) > 1:
             _unused = "{} are not used ".format(check_res.unused)
         else:
@@ -370,9 +329,7 @@ def _check_forward_error(model_func, check_level, batch_x):
     elif check_level == WARNING_CHECK_LEVEL:
         warnings.warn(message=_unused)


-def _check_loss_evaluate(prev_func, func, check_level, output, batch_y):
-    check_res = _check_arg_dict_list(func, [output, batch_y])
+def _check_loss_evaluate(prev_func, func, check_res, output, batch_y, check_level):
     _missing = ''
     _unused = ''
     _duplicated = ''


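The loss check above no longer special-cases uninitialized losses; it simply calls the losser and catches CheckError, which is expected to carry check_res and the signature of the function whose arguments failed to match. The raise site itself is not part of this diff, so the following is only a hedged sketch of that producer side, reusing the helpers this commit imports (the get_loss method name is an assumption):

    # Hedged sketch: how a LossBase-style __call__ might produce the CheckError
    # that _check_code catches above. _check_arg_dict_list, _build_args and
    # get_func_signature are imported throughout this commit; the exact raise
    # site and the get_loss method name are assumptions.
    def __call__(self, output_dict, target_dict):
        check_res = _check_arg_dict_list(self.get_loss, [output_dict, target_dict])
        if check_res.missing or check_res.duplicated:
            raise CheckError(check_res=check_res,
                             func_signature=get_func_signature(self.get_loss))
        refined_args = _build_args(self.get_loss, **output_dict, **target_dict)
        return self.get_loss(**refined_args)
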
+10 -7   fastNLP/core/utils.py

@@ -220,13 +220,16 @@ class CheckError(Exception):


     CheckError. Used in losses.LossBase, metrics.MetricBase.
     """
-    def __init__(self, check_res):
+    def __init__(self, check_res:CheckRes, func_signature:str):
         err = ''
-        if check_res['missing']:
-            err += f"Missing: {check_res['missing']}\n"
-        if check_res['duplicated']:
-            err += f"Duplicated: {check_res['duplicated']}\n"
-        if check_res['unused']:
-            err += f"Unused: {check_res['unused']}\n"
+        if check_res.missing:
+            err += f"Missing: {check_res.missing}\n"
+        if check_res.duplicated:
+            err += f"Duplicated: {check_res.duplicated}\n"
+        if check_res.unused:
+            err += f"Unused: {check_res.unused}\n"

         Exception.__init__(self, err)

         self.check_res = check_res
+        self.func_signature = func_signature

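With func_signature now stored on the exception, a caller like _check_code can report exactly which function's argument check failed. A small usage sketch, assuming CheckRes is a namedtuple-like record with missing/unused/duplicated fields (inferred from the attribute accesses above; its real definition is not shown in this diff):

    # Hedged usage sketch of the updated CheckError.
    from collections import namedtuple

    # Assumed shape of CheckRes, inferred from the fields read above.
    CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated'])

    res = CheckRes(missing=['target'], unused=['seq_len'], duplicated=[])
    try:
        raise CheckError(check_res=res, func_signature='get_loss(pred, target)')
    except CheckError as e:
        print(e.func_signature)     # -> get_loss(pred, target)
        print(e.check_res.missing)  # -> ['target']
        print(e)                    # message lists the Missing and Unused keys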