
CheckError: add func_signature

tags/v0.2.0^2
yh · 6 years ago · commit 0d4720b1d9
4 changed files with 57 additions and 105 deletions:
  1. fastNLP/core/metrics.py (+7 -21)
  2. fastNLP/core/tester.py (+18 -12)
  3. fastNLP/core/trainer.py (+22 -65)
  4. fastNLP/core/utils.py (+10 -7)

fastNLP/core/metrics.py (+7 -21)

@@ -8,6 +8,8 @@ import torch
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import CheckError


class MetricBase(object):
def __init__(self):
@@ -29,7 +31,7 @@ class MetricBase(object):
if not isinstance(value, str):
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
self.param_map[key] = value
def __call__(self, output_dict, target_dict, force_check=False):
"""
:param output_dict:
@@ -67,7 +69,7 @@ class MetricBase(object):
check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_target_dict])
self._reverse_param_map = {value: key for key, value in self.param_map.items()}
for key, value in check_res.items():
new_value = value.copy()
new_value = list(value)
for idx, func_param in enumerate(value):
if func_param in self._reverse_param_map:
new_value[idx] = self._reverse_param_map[func_param]
@@ -85,28 +87,12 @@ class MetricBase(object):
return metrics





class CheckError(Exception):
def __init__(self, check_res):

err = ''
if check_res.missing:
err += f'Missing: {check_res.missing}\n'
if check_res.duplicated:
err += f'Duplicated: {check_res.duplicated}\n'
self.check_res = check_res

def __str__(self):
pass


class Metric(MetricBase):
def __init__(self, func, key_map, **kwargs):
super().__init__()
pass


def _prepare_metrics(metrics):
"""

@@ -127,8 +113,8 @@ def _prepare_metrics(metrics):
elif isinstance(metrics, MetricBase):
_metrics = [metrics]
else:
raise TypeError("The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, got {}."
.format(type(metrics)))
raise TypeError(f"The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, "
f"got {type(metrics)}.")
return _metrics
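
For reference, _prepare_metrics simply normalizes whatever the caller passes into a list of MetricBase instances. A self-contained sketch of that dispatch, using a stand-in MetricBase rather than the real fastNLP class:

class MetricBase:  # stand-in for fastNLP.core.metrics.MetricBase
    pass

def prepare_metrics(metrics):
    # mirrors the branches in the hunk above: list -> validated list,
    # single instance -> one-element list, anything else -> TypeError
    if isinstance(metrics, list):
        if not all(isinstance(m, MetricBase) for m in metrics):
            raise TypeError("every element of `metrics` must be a MetricBase")
        return metrics
    if isinstance(metrics, MetricBase):
        return [metrics]
    raise TypeError(f"The type of metrics should be `list[fastNLP.MetricBase]` or "
                    f"`fastNLP.MetricBase`, got {type(metrics)}.")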




fastNLP/core/tester.py (+18 -12)

@@ -5,12 +5,13 @@ import torch
from torch import nn

from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.dataset import DataSet
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.utils import CheckError

class Tester(object):
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. """
@@ -33,7 +34,7 @@ class Tester(object):
raise TypeError(f"`{_model_name}.predict` must be callable to be used "
f"for evaluation, not `{type(self._predict_func)}`.")
else:
self._predict_func = self._model
self._predict_func = self._model.forward

self.data = data
if torch.cuda.is_available() and self.use_cuda:
@@ -50,14 +51,14 @@ class Tester(object):
def test(self):
# turn on the testing mode; clean up the history
network = self._model
self.mode(network, is_test=True)
self._mode(network, is_test=True)
output, truths = defaultdict(list), defaultdict(list)
data_iterator = Batch(self.data, self.batch_size, sampler=RandomSampler(), as_numpy=False)
data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False)

with torch.no_grad():
for batch_x, batch_y in data_iterator:
_move_dict_value_to_device(self._model_device, batch_x, batch_y)
prediction = self.data_forward(network, batch_x)
prediction = self._data_forward(self._predict_func, batch_x)
assert isinstance(prediction, dict)
for k, v in prediction.items():
output[k].append(v)
@@ -68,16 +69,21 @@ class Tester(object):
for k, v in truths.items():
truths[k] = itertools.chain(*v)
eval_results = {}
try:
for metric in self.metrics:
eval_result = metric(output, truths)
metric_name = metric.__class__.__name__
eval_results[metric_name] = eval_result
except CheckError as e:
pass


if self.verbose >= 0:
print("[tester] \n{}".format(self.format_eval_results(eval_results)))
self.mode(network, is_test=False)
print("[tester] \n{}".format(self._format_eval_results(eval_results)))
self._mode(network, is_test=False)
return eval_results

def mode(self, model, is_test=False):
def _mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.

:param model: a PyTorch model
@@ -89,13 +95,13 @@ class Tester(object):
else:
model.train()

def data_forward(self, network, x):
def _data_forward(self, func, x):
"""A forward pass of the model. """
x = _build_args(network.forward, **x)
y = self._predict_func(**x)
x = _build_args(func, **x)
y = func(**x)
return y

def format_eval_results(self, results):
def _format_eval_results(self, results):
"""Override this method to support more print formats.

:param results: dict, (str: float) is (metrics name: value)
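
The reworked _data_forward now receives the function to call explicitly and filters the batch through _build_args before calling it. A rough, self-contained sketch of that idea (illustrative only, not fastNLP's actual _build_args; toy_model is a made-up example):

import inspect

def build_args(func, **kwargs):
    # keep only the keyword arguments that `func` actually accepts;
    # if func takes **kwargs itself, everything can be passed through
    params = inspect.signature(func).parameters
    if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return kwargs
    return {name: value for name, value in kwargs.items() if name in params}

def data_forward(func, batch_x):
    # the Tester now forwards through self._predict_func (predict or forward)
    x = build_args(func, **batch_x)
    return func(**x)

# example: extra keys in the batch are silently dropped for this func
def toy_model(words):
    return {"pred": words}

print(data_forward(toy_model, {"words": [1, 2, 3], "seq_len": 3}))  # {'pred': [1, 2, 3]}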


fastNLP/core/trainer.py (+22 -65)

@@ -25,7 +25,7 @@ from fastNLP.core.losses import LossBase
from fastNLP.core.metrics import MetricBase
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.utils import CheckError

class Trainer(object):
"""Main Training Loop
@@ -211,13 +211,11 @@ class Trainer(object):
def get_loss(self, predict, truth):
"""Compute loss given prediction and ground truth.

:param predict: prediction label vector
:param truth: ground truth label vector
:param predict: prediction dict, produced by model.forward
:param truth: ground truth dict, produced by batch_y
:return: a scalar
"""
assert isinstance(predict, dict) and isinstance(truth, dict)
args = _build_args(self.loss_func, **predict, **truth)
return self.loss_func(**args)
return self.losser(predict, truth)

def save_model(self, model, model_name, only_param=False):
model_name = os.path.join(self.save_path, model_name)
@@ -260,11 +258,11 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
dev_data=None,
check_level=WARNING_CHECK_LEVEL):
# check the get_loss method
model_name = model.__class__.__name__
model_device = model.parameters().__next__().device

batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_count, (batch_x, batch_y) in enumerate(batch):
_syn_model_data(model, batch_x, batch_y)
_move_dict_value_to_device(model_device, batch_x, batch_y)
# forward check
if batch_count==0:
_check_forward_error(model_func=model.forward, check_level=check_level,
@@ -277,68 +275,29 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`.")

# loss check
if isinstance(losser, type):  # in this case the user passed an uninitialized loss such as losser.CE
# need to make sure output and batch_y are unambiguous?
# (1) output and batch_y both have length 1
# (2) the keys of output and batch_y exactly match what the losser accepts
pass

loss = losser(output, batch_y)

try:
loss = losser(output, batch_y)
except CheckError as e:
_check_loss_evaluate(prev_func=model.forward, func=e.func_signature,
check_res=e.check_res, output=output, batch_y=batch_y,
check_level=check_level)
# check loss output
if batch_count == 0:
if not isinstance(loss, torch.Tensor):
raise ValueError("The return value of {} should be `torch.Tensor`, but got `{}`.".
format(type(losser), type(loss)))
raise TypeError(f"The return value of {get_func_signature(losser.__call__)} should be `torch.Tensor`, "
f"but got `{type(loss)}`.")
if len(loss.size())!=0:
raise ValueError("The size of return value of {} is {}, should be torch.size([])".format(
type(losser), loss.size()
))
raise ValueError(f"The size of return value of {get_func_signature(losser.__call__)} is {loss.size()}, "
f"should be torch.size([])")
loss.backward()
model.zero_grad()
if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
break

if dev_data is not None:
outputs, truths = defaultdict(list), defaultdict(list)
dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
# TODO change this to use the tester
tester = Tester(data=dataset, model=model, metrics=metrics, batch_size=batch_size, )

with torch.no_grad():
for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
_syn_model_data(model, batch_x, batch_y)

if hasattr(model, 'predict'):
if not callable(model.predict):
raise TypeError(f"{get_func_signature(model.predict)} must be callable to be used "
f"for evaluation.")
refined_batch_x = _build_args(model.predict, **batch_x)
prev_func = model.predict
output = prev_func(**refined_batch_x)
else:
refined_batch_x = _build_args(model.forward, **batch_x)
prev_func = model.forward
output = prev_func(**refined_batch_x)
func_signature = get_func_signature(prev_func)
if not isinstance(output, dict):
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`")
for k, v in output.items():
outputs[k].append(v)
for k, v in batch_y.items():
truths[k].append(v)
if batch_count+1>DEFAULT_CHECK_NUM_BATCH:
break
for k, v in outputs.items():
outputs[k] = tuple(itertools.chain(*v))
for k, v in truths.items():
truths[k] = tuple(itertools.chain(*v))
# TODO this needs to be updated for the new metrics; errors raised by the metric also need to be caught here, to guide the user's debugging





tester = Tester(data=dataset[:batch_size*DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
batch_size=batch_size, verbose=-1)
tester.test()
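
The loss check above insists that the losser returns a scalar (0-dim) tensor before backward() is attempted. A standalone illustration of that requirement; check_loss_value is a hypothetical helper, not part of fastNLP:

import torch

def check_loss_value(loss, losser_name="losser"):
    # mirrors the two checks in _check_code: first the type, then the 0-dim shape
    if not isinstance(loss, torch.Tensor):
        raise TypeError(f"The return value of {losser_name} should be `torch.Tensor`, "
                        f"but got `{type(loss)}`.")
    if len(loss.size()) != 0:
        raise ValueError(f"The size of return value of {losser_name} is {loss.size()}, "
                         f"should be torch.Size([])")

check_loss_value(torch.tensor(0.3))       # fine: a scalar tensor, backward() can follow
# check_loss_value(torch.tensor([0.3]))   # would raise: size is torch.Size([1]), not torch.Size([])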


def _check_forward_error(model_func, check_level, batch_x):
@@ -346,11 +305,11 @@ def _check_forward_error(model_func, check_level, batch_x):
_missing = ''
_unused = ''
func_signature = get_func_signature(model_func)
if len(check_res.missing)!=0:
if len(check_res['missing'])!=0:
_missing = "Function {} misses {}, only provided with {}, " \
".\n".format(func_signature, check_res.missing,
list(batch_x.keys()))
if len(check_res.unused)!=0:
if len(check_res['unused'])!=0:
if len(check_res.unused) > 1:
_unused = "{} are not used ".format(check_res.unused)
else:
@@ -370,9 +329,7 @@ def _check_forward_error(model_func, check_level, batch_x):
elif check_level == WARNING_CHECK_LEVEL:
warnings.warn(message=_unused)

def _check_loss_evaluate(prev_func, func, check_level, output, batch_y):

check_res = _check_arg_dict_list(func, [output, batch_y])
def _check_loss_evaluate(prev_func, func, check_res, output, batch_y, check_level):
_missing = ''
_unused = ''
_duplicated = ''


fastNLP/core/utils.py (+10 -7)

@@ -220,13 +220,16 @@ class CheckError(Exception):

CheckError. Used in losses.LossBase, metrics.MetricBase.
"""
def __init__(self, check_res):
def __init__(self, check_res:CheckRes, func_signature:str):
err = ''
if check_res['missing']:
err += f"Missing: {check_res['missing']}\n"
if check_res['duplicated']:
err += f"Duplicated: {check_res['duplicated']}\n"
if check_res['unused']:
err += f"Unused: {check_res['unused']}\n"
if check_res.missing:
err += f"Missing: {check_res.missing}\n"
if check_res.duplicated:
err += f"Duplicated: {check_res.duplicated}\n"
if check_res.unused:
err += f"Unused: {check_res.unused}\n"

Exception.__init__(self, err)

self.check_res = check_res
self.func_signature = func_signature
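
With func_signature attached, callers such as Tester.test and _check_code can report which function the failed check came from. A hedged sketch of consuming the new payload in an except block (report_check_error and run_metric are hypothetical names; the check_res fields mirror the attribute access above):

from fastNLP.core.utils import CheckError

def report_check_error(run_metric):
    # run_metric: any callable that may raise CheckError with the new payload
    try:
        return run_metric()
    except CheckError as e:
        print(f"Check failed when calling {e.func_signature}:")
        print(f"  missing:    {e.check_res.missing}")
        print(f"  unused:     {e.check_res.unused}")
        print(f"  duplicated: {e.check_res.duplicated}")
        raise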
