
trainer and tester change check_code

tags/v0.2.0^2
yh committed 6 years ago
commit 3a4a729314
4 changed files with 103 additions and 125 deletions
  1. fastNLP/core/metrics.py (+10 -5)
  2. fastNLP/core/tester.py (+4 -2)
  3. fastNLP/core/trainer.py (+20 -110)
  4. fastNLP/core/utils.py (+69 -8)

fastNLP/core/metrics.py (+10 -5)

@@ -1,6 +1,7 @@
 import warnings
 import inspect
+from collections import defaultdict

 import numpy as np
 import torch
@@ -21,6 +22,7 @@ class MetricBase(object):

     def _init_param_map(self, key_map, **kwargs):
         self.param_map = {}
+        value_counter = defaultdict(set)
         for key, value in key_map.items():
             if not isinstance(key, str):
                 raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
@@ -32,16 +34,19 @@ class MetricBase(object):
                 raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
             self.param_map[key] = value

-    def __call__(self, output_dict, target_dict, force_check=False):
+    def __call__(self, output_dict, target_dict, check=False):
         """
         :param output_dict:
         :param target_dict:
+        :param check: boolean, if True, force the argument check even if it has already passed once.
         :return:
         """
         if not callable(self.evaluate):
             raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

         if not self._checked:
+            # 0. check param_map does not have same value
             # 1. check consistency between signature and param_map
             func_spect = inspect.getfullargspec(self.evaluate)
             func_args = func_spect.args
@@ -65,7 +70,7 @@ class MetricBase(object):
                     mapped_target_dict[func_arg] = target_dict[input_arg]

             # check duplicated, unused, missing
-            if force_check or not self._checked:
+            if check or not self._checked:
                 check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_target_dict])
                 self._reverse_param_map = {value: key for key, value in self.param_map.items()}
                 for key, value in check_res._asdict().items():
@@ -73,8 +78,9 @@ class MetricBase(object):
                     for idx, func_param in enumerate(value):
                         if func_param in self._reverse_param_map:
                             new_value[idx] = self._reverse_param_map[func_param]
-            if check_res.missing or check_res.duplicated:
-                raise CheckError(check_res=check_res)
+            if check_res.missing or check_res.duplicated or check_res.varargs:
+                raise CheckError(check_res=check_res,
+                                 func_signature=get_func_signature(self.evaluate))
             refined_args = _build_args(self.evaluate, **mapped_output_dict, **mapped_target_dict)

         metrics = self.evaluate(**refined_args)
@@ -92,7 +98,6 @@ class Metric(MetricBase):
         super().__init__()
         pass

-
 def _prepare_metrics(metrics):
     """

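What this file's hunks add up to: MetricBase.__call__ renames keys in output_dict/target_dict through param_map before dispatching to evaluate, and now raises CheckError carrying the evaluate signature when arguments are missing, duplicated, or passed as varargs. Below is a minimal, self-contained sketch of the key-mapping flow only; ToyMetric and its key names are hypothetical, not fastNLP's API, and the real class additionally caches the check and raises CheckError.

import inspect

def build_args(func, **kwargs):
    # keep only the kwargs that actually appear in func's signature
    spect = inspect.getfullargspec(func)
    return {k: v for k, v in kwargs.items() if k in spect.args}

class ToyMetric:
    def __init__(self, param_map=None):
        # maps evaluate() argument names to keys in the data dicts,
        # e.g. {"pred": "output", "target": "label_seq"}
        self.param_map = param_map or {}

    def evaluate(self, pred, target):
        correct = sum(p == t for p, t in zip(pred, target))
        return {"acc": correct / len(target)}

    def __call__(self, output_dict, target_dict):
        mapped = {}
        for func_arg in inspect.getfullargspec(self.evaluate).args:
            if func_arg == "self":
                continue
            data_key = self.param_map.get(func_arg, func_arg)
            if data_key in output_dict:
                mapped[func_arg] = output_dict[data_key]
            elif data_key in target_dict:
                mapped[func_arg] = target_dict[data_key]
        return self.evaluate(**build_args(self.evaluate, **mapped))

metric = ToyMetric(param_map={"pred": "output", "target": "label_seq"})
print(metric({"output": [1, 0, 1]}, {"label_seq": [1, 1, 1]}))  # {'acc': 0.666...}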
fastNLP/core/tester.py (+4 -2)

@@ -12,6 +12,7 @@ from fastNLP.core.utils import get_func_signature
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.utils import CheckError
+from fastNLP.core.utils import _check_loss_evaluate


 class Tester(object):
     """A collection of model inference and evaluation of performance, used over validation/dev set and test set."""
@@ -47,7 +48,6 @@ class Tester(object):
         self._model_device = model.parameters().__next__().device

-
     def test(self):
         # turn on the testing mode; clean up the history
         network = self._model
@@ -75,7 +75,9 @@ class Tester(object):
                         metric_name = metric.__class__.__name__
                         eval_results[metric_name] = eval_result
                 except CheckError as e:
-                    pass
+                    prev_func_signature = get_func_signature(self._predict_func)
+                    _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
+                                         check_res=e.check_res, output=output, batch_y=truths)

         if self.verbose >= 0:


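The tester change is the point of the commit in miniature: instead of silently swallowing a CheckError, it forwards the error's structured check_res to _check_loss_evaluate together with the predict function's signature, so the resulting message names both the producer (predict) and the consumer (evaluate). A toy reconstruction of that pattern follows, with made-up names rather than fastNLP's classes:

class CheckError(Exception):
    """Carries structured check results plus the offending signature."""
    def __init__(self, check_res, func_signature):
        super().__init__(f"problems when calling {func_signature}")
        self.check_res = check_res
        self.func_signature = func_signature

def evaluate(**kwargs):
    # stand-in for metric.evaluate; pretend the argument check failed
    raise CheckError({"missing": ["target"]}, "evaluate(pred, target)")

def test_step(predict_signature, output, truths):
    try:
        return evaluate(**output)
    except CheckError as e:
        # name both ends of the contract: what predict() produced
        # and what evaluate() expected
        raise NameError(
            f"{e.func_signature} misses {e.check_res['missing']}; only provided "
            f"with {list(output.keys())} (from {predict_signature}) and "
            f"{list(truths.keys())} (from targets in Dataset).") from e

# test_step("Model.predict(self, x)", {"pred": [1, 0]}, {"label": [1, 1]})
# -> NameError: evaluate(pred, target) misses ['target']; ...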
fastNLP/core/trainer.py (+20 -110)

@@ -20,12 +20,11 @@ from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.utils import get_func_signature
 from fastNLP.core.dataset import DataSet
-from fastNLP.core.losses import LossBase
-from fastNLP.core.metrics import MetricBase
 from fastNLP.core.losses import _prepare_losser
 from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.utils import CheckError
+from fastNLP.core.utils import _check_loss_evaluate
+from fastNLP.core.utils import _check_forward_error


 class Trainer(object):
     """Main Training Loop
@@ -33,7 +32,7 @@ class Trainer(object):
     """
     def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
                  dev_data=None, use_cuda=False, save_path="./save",
-                 optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
+                 optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), check_code_level=0,
                  **kwargs):
         super(Trainer, self).__init__()


@@ -53,8 +52,9 @@ class Trainer(object):
         # prepare loss
         losser = _prepare_losser(losser)

-        if need_check_code:
-            _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data)
+        if check_code_level > -1:
+            _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
+                        check_level=check_code_level)

         self.train_data = train_data
         self.dev_data = dev_data  # If None, No validation.
@@ -250,13 +250,9 @@
 DEFAULT_CHECK_BATCH_SIZE = 2
 DEFAULT_CHECK_NUM_BATCH = 2

-IGNORE_CHECK_LEVEL = 0
-WARNING_CHECK_LEVEL = 1
-STRICT_CHECK_LEVEL = 2
-
 def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                 dev_data=None,
-                check_level=WARNING_CHECK_LEVEL):
+                check_level=0):
     # check the get_loss method
     model_device = model.parameters().__next__().device


@@ -265,7 +261,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
         _move_dict_value_to_device(model_device, batch_x, batch_y)
         # forward check
         if batch_count == 0:
-            _check_forward_error(model_func=model.forward, check_level=check_level,
+            _check_forward_error(forward_func=model.forward, check_level=check_level,
                                  batch_x=batch_x)

         refined_batch_x = _build_args(model.forward, **batch_x)
@@ -277,19 +273,21 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
         # loss check
         try:
             loss = losser(output, batch_y)
+            # check loss output
+            if batch_count == 0:
+                if not isinstance(loss, torch.Tensor):
+                    raise TypeError(
+                        f"The return value of {get_func_signature(losser.get_loss)} should be `torch.Tensor`, "
+                        f"but got `{type(loss)}`.")
+                if len(loss.size()) != 0:
+                    raise ValueError(
+                        f"The size of return value of {get_func_signature(losser.get_loss)} is {loss.size()}, "
+                        f"should be torch.Size([])")
+            loss.backward()
         except CheckError as e:
             _check_loss_evaluate(prev_func_signature=get_func_signature(model.forward), func_signature=e.func_signature,
                                  check_res=e.check_res, output=output, batch_y=batch_y,
                                  check_level=check_level)
-        # check loss output
-        if batch_count == 0:
-            if not isinstance(loss, torch.Tensor):
-                raise TypeError(f"The return value of {get_func_signature(losser.__call__)} should be `torch.Tensor`, "
-                                f"but got `{type(loss)}`.")
-            if len(loss.size()) != 0:
-                raise ValueError(f"The size of return value of {get_func_signature(losser.__call__)} is {loss.size()}, "
-                                 f"should be torch.Size([])")
-        loss.backward()
         model.zero_grad()
         if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
             break
@@ -300,93 +298,5 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
         tester.test()


-def _check_forward_error(model_func, check_level, batch_x):
-    check_res = _check_arg_dict_list(model_func, batch_x)
-    _missing = ''
-    _unused = ''
-    func_signature = get_func_signature(model_func)
-    if len(check_res['missing']) != 0:
-        _missing = "Function {} misses {}, only provided with {}, " \
-                   ".\n".format(func_signature, check_res.missing,
-                                list(batch_x.keys()))
-    if len(check_res['unused']) != 0:
-        if len(check_res.unused) > 1:
-            _unused = "{} are not used ".format(check_res.unused)
-        else:
-            _unused = "{} is not used ".format(check_res.unused)
-        _unused += "in function {}.\n".format(func_signature)
-    if _missing:
-        if len(_unused) > 0 and STRICT_CHECK_LEVEL:
-            _error_str = "(1).{}\n(2).{}".format(_missing, _unused)
-        else:
-            _error_str = _missing
-        # TODO: may need some custom Error types here
-        raise TypeError(_error_str)
-    if _unused:
-        if check_level == STRICT_CHECK_LEVEL:
-            # TODO: may need some custom Error types here
-            raise ValueError(_unused)
-        elif check_level == WARNING_CHECK_LEVEL:
-            warnings.warn(message=_unused)
-
-def _check_loss_evaluate(prev_func, func, check_res, output, batch_y, check_level):
-    _missing = ''
-    _unused = ''
-    _duplicated = ''
-    func_signature = get_func_signature(func)
-    prev_func_signature = get_func_signature(prev_func)
-    if len(check_res.missing) > 0:
-        _missing = "function {} misses argument {}, \n\t only provided with {}(from {}) and " \
-                   "{}(from target in Dataset)." \
-            .format(func_signature, check_res.missing,
-                    list(output.keys()), prev_func_signature,
-                    list(batch_y.keys()))
-    if len(check_res.unused) > 0:
-        if len(check_res.unused) > 1:
-            _unused = "{} are not used ".format(check_res.unused)
-        else:
-            _unused = "{} is not used ".format(check_res.unused)
-        _unused += "in function {}.\n".format(func_signature)
-    if len(check_res.duplicated) > 0:
-        if len(check_res.duplicated) > 1:
-            _duplicated = "duplicated keys {} are detected when calling function {}. \n\tDon't set {} as target and output " \
-                          "them in {} at the same time.".format(check_res.duplicated,
-                                                                func_signature,
-                                                                check_res.duplicated,
-                                                                prev_func_signature)
-        else:
-            _duplicated = "duplicated key {} is detected when calling function {}. \n\tDon't set {} as target and output " \
-                          "it in {} at the same time.".format(check_res.duplicated,
-                                                              func_signature,
-                                                              check_res.duplicated,
-                                                              prev_func_signature)
-    _number_errs = int(len(_missing) != 0) + int(len(_duplicated) != 0) + int(len(_unused) != 0)
-    if _number_errs > 0:
-        _error_strs = []
-        if _number_errs > 1:
-            count = 0
-            order_words = ['Firstly', 'Secondly', 'Thirdly']
-            if _missing:
-                _error_strs.append('{}, {}'.format(order_words[count], _missing))
-                count += 1
-            if _duplicated:
-                _error_strs.append('{}, {}'.format(order_words[count], _duplicated))
-                count += 1
-            if _unused and check_level == STRICT_CHECK_LEVEL:
-                _error_strs.append('{}, {}'.format(order_words[count], _unused))
-        else:
-            if _unused:
-                if check_level == STRICT_CHECK_LEVEL:
-                    # TODO: may need some custom Error types here
-                    _error_strs.append(_unused)
-                elif check_level == WARNING_CHECK_LEVEL:
-                    _unused = _unused.strip()
-                    warnings.warn(_unused)
-            else:
-                if _missing:
-                    _error_strs.append(_missing)
-                if _duplicated:
-                    _error_strs.append(_duplicated)
-
-    if _error_strs:
-        raise ValueError('\n' + '\n'.join(_error_strs))


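Two behavioural changes are worth noting in this file. First, the boolean need_check_code becomes a graded check_code_level (-1 disables the check entirely; 0/1/2 correspond to the ignore/warning/strict constants now defined in utils.py). Second, the scalar-loss validation moves inside the try block, so loss.backward() is only reached after the loss has been checked. A standalone sketch of that guard, assuming a standard PyTorch loss; the helper name check_loss_output is mine, not fastNLP's:

import torch
import torch.nn.functional as F

def check_loss_output(loss, loss_fn_name="get_loss"):
    # mirror of the commit's guard: loss must be a 0-dim tensor,
    # otherwise loss.backward() is ill-defined without extra gradients
    if not isinstance(loss, torch.Tensor):
        raise TypeError(f"The return value of {loss_fn_name} should be "
                        f"`torch.Tensor`, but got `{type(loss)}`.")
    if len(loss.size()) != 0:
        raise ValueError(f"The return value of {loss_fn_name} has size "
                         f"{loss.size()}, but torch.Size([]) is required.")

pred = torch.randn(4, 3, requires_grad=True)
target = torch.randint(0, 3, (4,))
loss = F.cross_entropy(pred, target)          # reduction='mean' -> 0-dim
check_loss_output(loss)                       # passes
loss.backward()

unreduced = F.cross_entropy(pred, target, reduction="none")  # shape (4,)
# check_loss_output(unreduced)                # would raise ValueError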
fastNLP/core/utils.py (+69 -8)

@@ -1,11 +1,14 @@
 import _pickle
 import inspect
 import os
+import warnings
 from collections import Counter
 from collections import namedtuple
 import torch

-CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False)
+CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
+                                   'varargs'], verbose=False)

 def save_pickle(obj, pickle_path, file_name):
     """Save an object into a pickle file.


@@ -105,7 +108,6 @@ def _check_arg_dict_list(func, args):
     assert callable(func) and isinstance(arg_dict_list, (list, tuple))
     assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict)
     spect = inspect.getfullargspec(func)
-    assert spect.varargs is None, 'Positional Arguments({}) are not supported.'.format(spect.varargs)
     all_args = set([arg for arg in spect.args if arg != 'self'])
     defaults = []
     if spect.defaults is not None:
@@ -125,7 +127,8 @@ def _check_arg_dict_list(func, args):
                     unused=unused,
                     duplicated=duplicated,
                     required=list(require_args),
-                    all_needed=list(all_args))
+                    all_needed=list(all_args),
+                    varargs=[spect.varargs] if spect.varargs else [])

 def get_func_signature(func):
     """
@@ -221,15 +224,73 @@ class CheckError(Exception):
     CheckError. Used in losses.LossBase, metrics.MetricBase.
     """
     def __init__(self, check_res: CheckRes, func_signature: str):
-        err = ''
+        errs = [f'The following problems occurred when calling {func_signature}']
+
+        if check_res.varargs:
+            errs.append(f"\tvarargs: {check_res.varargs} (passing positional arguments is not supported, "
+                        f"please remove them)")
         if check_res.missing:
-            err += f"Missing: {check_res.missing}\n"
+            errs.append(f"\tmissing param: {check_res.missing}")
         if check_res.duplicated:
-            err += f"Duplicated: {check_res.duplicated}\n"
+            errs.append(f"\tduplicated param: {check_res.duplicated}")
         if check_res.unused:
-            err += f"Unused: {check_res.unused}\n"
+            errs.append(f"\tunused param: {check_res.unused}")

-        Exception.__init__(self, err)
+        Exception.__init__(self, '\n'.join(errs))

         self.check_res = check_res
         self.func_signature = func_signature

+
+IGNORE_CHECK_LEVEL = 0
+WARNING_CHECK_LEVEL = 1
+STRICT_CHECK_LEVEL = 2
+
+
+def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: CheckRes,
+                         output: dict, batch_y: dict, check_level=0):
+    errs = []
+    _unused = []
+    if check_res.varargs:
+        errs.append(f"\tvarargs: {check_res.varargs} (passing positional arguments is not supported, "
+                    f"please remove them)")
+    if check_res.missing:
+        errs.append(f"\tmissing param: {check_res.missing}, only provided with {list(output.keys())} "
+                    f"(from {prev_func_signature}) and {list(batch_y.keys())} (from targets in Dataset).")
+    if check_res.duplicated:
+        errs.append(f"\tduplicated param: {check_res.duplicated}, delete {check_res.duplicated} in the output of "
+                    f"{prev_func_signature} or do not set {check_res.duplicated} as targets.")
+    if check_res.unused:
+        _unused = [f"\tunused param: {check_res.unused}"]
+        if check_level == STRICT_CHECK_LEVEL:
+            errs.extend(_unused)
+
+    if len(errs) > 0:
+        errs.insert(0, f'The following problems occurred when calling {func_signature}')
+        raise NameError('\n'.join(errs))
+    if _unused:
+        if check_level == WARNING_CHECK_LEVEL:
+            _unused_warn = _unused[0] + f' in {func_signature}.'
+            warnings.warn(message=_unused_warn)
+
+
+def _check_forward_error(forward_func, batch_x, check_level):
+    check_res = _check_arg_dict_list(forward_func, batch_x)
+    func_signature = get_func_signature(forward_func)
+
+    errs = []
+    _unused = []
+
+    if check_res.varargs:
+        errs.append(f"\tvarargs: {check_res.varargs} (passing positional arguments is not supported, "
+                    f"please remove them)")
+    if check_res.missing:
+        errs.append(f"\tmissing param: {check_res.missing}, only provided with {list(batch_x.keys())}.")
+    if check_res.unused:
+        _unused = [f"\tunused param: {check_res.unused}"]
+        if check_level == STRICT_CHECK_LEVEL:
+            errs.extend(_unused)
+
+    if len(errs) > 0:
+        errs.insert(0, f'The following problems occurred when calling {func_signature}')
+        raise NameError('\n'.join(errs))
+    if _unused:
+        if check_level == WARNING_CHECK_LEVEL:
+            _unused_warn = _unused[0] + f' in {func_signature}.'
+            warnings.warn(message=_unused_warn)

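For readers unfamiliar with the machinery being extended here: _check_arg_dict_list compares a function's signature (via inspect.getfullargspec) against the union of the provided argument dicts, and the new varargs field records a *args parameter as data instead of asserting it away, so the callers above can decide whether to raise or warn per check_level. A simplified, self-contained rendition, with the CheckRes fields trimmed to the ones used in the error paths:

import inspect
from collections import Counter, namedtuple

CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'varargs'])

def check_arg_dict_list(func, arg_dict_list):
    spect = inspect.getfullargspec(func)
    all_args = {arg for arg in spect.args if arg != 'self'}
    key_counter = Counter(k for d in arg_dict_list for k in d)
    return CheckRes(
        missing=sorted(all_args - set(key_counter)),
        unused=sorted(set(key_counter) - all_args),
        duplicated=sorted(k for k, n in key_counter.items() if n > 1),
        # spect.varargs is the *name* of the *args parameter, or None
        varargs=[spect.varargs] if spect.varargs else [])

def evaluate(pred, target):
    pass

print(check_arg_dict_list(evaluate, [{'pred': 1}, {'label': 2, 'pred': 3}]))
# CheckRes(missing=['target'], unused=['label'], duplicated=['pred'], varargs=[])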