Browse Source

trainer and tester change check_code

tags/v0.2.0^2
yh 6 years ago
parent
commit
3a4a729314
4 changed files with 103 additions and 125 deletions
  1. +10
    -5
      fastNLP/core/metrics.py
  2. +4
    -2
      fastNLP/core/tester.py
  3. +20
    -110
      fastNLP/core/trainer.py
  4. +69
    -8
      fastNLP/core/utils.py

+ 10
- 5
fastNLP/core/metrics.py View File

@@ -1,6 +1,7 @@

import warnings
import inspect
from collections import defaultdict

import numpy as np
import torch
@@ -21,6 +22,7 @@ class MetricBase(object):

def _init_param_map(self, key_map, **kwargs):
self.param_map = {}
value_counter = defaultdict(0)
for key, value in key_map.items():
if isinstance(key, str):
raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
@@ -32,16 +34,19 @@ class MetricBase(object):
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
self.param_map[key] = value

def __call__(self, output_dict, target_dict, force_check=False):
def __call__(self, output_dict, target_dict, check=False):
"""
:param output_dict:
:param target_dict:
:param check: boolean,
:return:
"""
if not callable(self.evaluate):
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

if not self._checked:
# 0. check param_map does not have same value

# 1. check consistence between signature and param_map
func_spect = inspect.getfullargspec(self.evaluate)
func_args = func_spect.args
@@ -65,7 +70,7 @@ class MetricBase(object):
mapped_target_dict[func_arg] = target_dict[input_arg]

# check duplicated, unused, missing
if force_check or not self._checked:
if check or not self._checked:
check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_output_dict])
self._reverse_param_map = {value:key for key, value in check_res.items()}
for key, value in check_res.items():
@@ -73,8 +78,9 @@ class MetricBase(object):
for idx, func_param in enumerate(value):
if func_param in self._reverse_param_map:
new_value[idx] = self._reverse_param_map[func_param]
if check_res.missing or check_res.duplicated:
raise CheckError(check_res=check_res)
if check_res.missing or check_res.duplicated or check_res.varargs:
raise CheckError(check_res=check_res,
func_signature=get_func_signature(self.evaluate))
refined_args = _build_args(self.evaluate, **mapped_output_dict, **mapped_target_dict)

metrics = self.evaluate(**refined_args)
@@ -92,7 +98,6 @@ class Metric(MetricBase):
super().__init__()
pass


def _prepare_metrics(metrics):
"""



+ 4
- 2
fastNLP/core/tester.py View File

@@ -12,6 +12,7 @@ from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _check_loss_evaluate

class Tester(object):
"""A collection of model inference and evaluation of performance, used over validation/dev set and test set. """
@@ -47,7 +48,6 @@ class Tester(object):

self._model_device = model.parameters().__next__().device


def test(self):
# turn on the testing mode; clean up the history
network = self._model
@@ -75,7 +75,9 @@ class Tester(object):
metric_name = metric.__class__.__name__
eval_results[metric_name] = eval_result
except CheckError as e:
pass
prev_func_signature = get_func_signature(self._predict_func)
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
check_res=e.check_res, output=output, batch_y=truths)


if self.verbose >= 0:


+ 20
- 110
fastNLP/core/trainer.py View File

@@ -20,12 +20,11 @@ from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature
from fastNLP.core.dataset import DataSet

from fastNLP.core.losses import LossBase
from fastNLP.core.metrics import MetricBase
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _check_loss_evaluate
from fastNLP.core.utils import _check_forward_error

class Trainer(object):
"""Main Training Loop
@@ -33,7 +32,7 @@ class Trainer(object):
"""
def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
dev_data=None, use_cuda=False, save_path="./save",
optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), check_code_level=0,
**kwargs):
super(Trainer, self).__init__()

@@ -53,8 +52,9 @@ class Trainer(object):
# prepare loss
losser = _prepare_losser(losser)

if need_check_code:
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data)
if check_code_level>-1:
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
check_level=check_code_level)

self.train_data = train_data
self.dev_data = dev_data # If None, No validation.
@@ -250,13 +250,9 @@ class Trainer(object):
DEFAULT_CHECK_BATCH_SIZE = 2
DEFAULT_CHECK_NUM_BATCH = 2

IGNORE_CHECK_LEVEL = 0
WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2

def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
dev_data=None,
check_level=WARNING_CHECK_LEVEL):
check_level=0):
# check get_loss 方法
model_devcie = model.parameters().__next__().device

@@ -265,7 +261,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
_move_dict_value_to_device(model_devcie, batch_x, batch_y)
# forward check
if batch_count==0:
_check_forward_error(model_func=model.forward, check_level=check_level,
_check_forward_error(forward_func=model.forward, check_level=check_level,
batch_x=batch_x)

refined_batch_x = _build_args(model.forward, **batch_x)
@@ -277,19 +273,21 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
# loss check
try:
loss = losser(output, batch_y)
# check loss output
if batch_count == 0:
if not isinstance(loss, torch.Tensor):
raise TypeError(
f"The return value of {get_func_signature(losser.get_loss)} should be `torch.Tensor`, "
f"but got `{type(loss)}`.")
if len(loss.size()) != 0:
raise ValueError(
f"The size of return value of {get_func_signature(losser.get_loss)} is {loss.size()}, "
f"should be torch.size([])")
loss.backward()
except CheckError as e:
_check_loss_evaluate(prev_func=model.forward, func=e.func_signature,
check_res=e.check_res, output=output, batch_y=batch_y,
check_level=check_level)
# check loss output
if batch_count == 0:
if not isinstance(loss, torch.Tensor):
raise TypeError(f"The return value of {get_func_signature(losser.__call__)} should be `torch.Tensor`, "
f"but got `{type(loss)}`.")
if len(loss.size())!=0:
raise ValueError(f"The size of return value of {get_func_signature(losser.__call__)} is {loss.size()}, "
f"should be torch.size([])")
loss.backward()
model.zero_grad()
if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
break
@@ -300,93 +298,5 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
tester.test()


def _check_forward_error(model_func, check_level, batch_x):
    """Check that the keys in ``batch_x`` match the parameters of ``model_func``.

    :param model_func: the model's ``forward`` function to validate against.
    :param check_level: one of IGNORE/WARNING/STRICT_CHECK_LEVEL; controls how
        unused batch keys are reported.
    :param batch_x: dict of candidate keyword arguments built from the dataset batch.
    :raises TypeError: if ``model_func`` requires arguments that ``batch_x`` misses.
    :raises ValueError: if ``batch_x`` has unused keys and the level is strict.
    """
    check_res = _check_arg_dict_list(model_func, batch_x)
    _missing = ''
    _unused = ''
    func_signature = get_func_signature(model_func)
    # CheckRes is a namedtuple: use attribute access, not item access
    # (check_res['missing'] would raise TypeError on a tuple).
    if len(check_res.missing) != 0:
        _missing = "Function {} misses {}, only provided with {}, " \
                   ".\n".format(func_signature, check_res.missing,
                                list(batch_x.keys()))
    if len(check_res.unused) != 0:
        if len(check_res.unused) > 1:
            _unused = "{} are not used ".format(check_res.unused)
        else:
            _unused = "{} is not used ".format(check_res.unused)
        _unused += "in function {}.\n".format(func_signature)
    if _missing:
        # Append the unused report only in strict mode. The previous condition
        # tested the bare constant STRICT_CHECK_LEVEL (always truthy), which
        # ignored the caller's actual check_level.
        if len(_unused) > 0 and check_level == STRICT_CHECK_LEVEL:
            _error_str = "(1).{}\n(2).{}".format(_missing, _unused)
        else:
            _error_str = _missing
        # TODO: a custom Error type may be needed here
        raise TypeError(_error_str)
    if _unused:
        if check_level == STRICT_CHECK_LEVEL:
            # TODO: a custom Error type may be needed here
            raise ValueError(_unused)
        elif check_level == WARNING_CHECK_LEVEL:
            warnings.warn(message=_unused)

def _check_loss_evaluate(prev_func, func, check_res, output, batch_y, check_level):
    """Build and report error/warning messages for a loss/metric call check.

    :param prev_func: the function that produced ``output`` (typically the
        model's ``forward``); named in messages so users know where keys came from.
    :param func: the loss/metric function whose arguments were checked.
    :param check_res: result of the argument check, with ``missing``,
        ``unused`` and ``duplicated`` fields.
    :param output: dict produced by ``prev_func``.
    :param batch_y: dict of target fields from the dataset.
    :param check_level: IGNORE/WARNING/STRICT_CHECK_LEVEL; decides whether
        unused keys raise, warn, or are ignored.
    :raises ValueError: when any problems (per ``check_level``) were found.
    """
    _missing = ''
    _unused = ''
    _duplicated = ''
    func_signature = get_func_signature(func)
    prev_func_signature = get_func_signature(prev_func)
    if len(check_res.missing)>0:
        _missing = "function {} misses argument {}, \n\t only provided with {}(from {}) and " \
                   "{}(from target in Dataset)." \
            .format(func_signature, check_res.missing,
                    list(output.keys()), prev_func_signature,
                    list(batch_y.keys()))
    if len(check_res.unused)>0:
        # Singular/plural phrasing depending on how many keys are unused.
        if len(check_res.unused) > 1:
            _unused = "{} are not used ".format(check_res.unused)
        else:
            _unused = "{} is not used ".format(check_res.unused)
        _unused += "in function {}.\n".format(func_signature)
    if len(check_res.duplicated)>0:
        if len(check_res.duplicated) > 1:
            _duplicated = "duplicated keys {} are detected when calling function {}. \n\tDon't set {} as target and output " \
                          "them in {} at the same time.".format(check_res.duplicated,
                                                                func_signature,
                                                                check_res.duplicated,
                                                                prev_func_signature)
        else:
            _duplicated = "duplicated key {} is detected when calling function {}. \n\tDon't set {} as target and output " \
                          "it in {} at the same time.".format(check_res.duplicated,
                                                              func_signature,
                                                              check_res.duplicated,
                                                              prev_func_signature)
    # Count how many distinct problem categories were found.
    _number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0)
    if _number_errs > 0:
        _error_strs = []
        if _number_errs > 1:
            # Multiple problems: enumerate them with ordering words.
            count = 0
            order_words = ['Firstly', 'Secondly', 'Thirdly']
            if _missing:
                _error_strs.append('{}, {}'.format(order_words[count], _missing))
                count += 1
            if _duplicated:
                _error_strs.append('{}, {}'.format(order_words[count], _duplicated))
                count += 1
            # Unused keys are only an error under strict checking.
            if _unused and check_level == STRICT_CHECK_LEVEL:
                _error_strs.append('{}, {}'.format(order_words[count], _unused))
        else:
            # Exactly one problem category.
            if _unused:
                if check_level == STRICT_CHECK_LEVEL:
                    # TODO: a custom Error type may be needed here
                    _error_strs.append(_unused)
                elif check_level == WARNING_CHECK_LEVEL:
                    # Under warning level an unused-only result warns
                    # instead of raising.
                    _unused = _unused.strip()
                    warnings.warn(_unused)
            else:
                if _missing:
                    _error_strs.append(_missing)
                if _duplicated:
                    _error_strs.append(_duplicated)

        if _error_strs:
            raise ValueError('\n' + '\n'.join(_error_strs))


+ 69
- 8
fastNLP/core/utils.py View File

@@ -1,11 +1,14 @@
import _pickle
import inspect
import os
import warnings
from collections import Counter
from collections import namedtuple
import torch

CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False)

CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
'varargs'], verbose=False)
def save_pickle(obj, pickle_path, file_name):
"""Save an object into a pickle file.

@@ -105,7 +108,6 @@ def _check_arg_dict_list(func, args):
assert callable(func) and isinstance(arg_dict_list, (list, tuple))
assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict)
spect = inspect.getfullargspec(func)
assert spect.varargs is None, 'Positional Arguments({}) are not supported.'.format(spect.varargs)
all_args = set([arg for arg in spect.args if arg!='self'])
defaults = []
if spect.defaults is not None:
@@ -125,7 +127,8 @@ def _check_arg_dict_list(func, args):
unused=unused,
duplicated=duplicated,
required=list(require_args),
all_needed=list(all_args))
all_needed=list(all_args),
varargs=[arg for arg in spect.varargs])

def get_func_signature(func):
"""
@@ -221,15 +224,73 @@ class CheckError(Exception):
CheckError. Used in losses.LossBase, metrics.MetricBase.
"""
def __init__(self, check_res:CheckRes, func_signature:str):
err = ''
errs = [f'The following problems occurred when calling {func_signature}']

if check_res.varargs:
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
if check_res.missing:
err += f"Missing: {check_res.missing}\n"
errs.append(f"\tmissing param: {check_res.missing}")
if check_res.duplicated:
err += f"Duplicated: {check_res.duplicated}\n"
errs.append(f"\tduplicated param: {check_res.duplicated}")
if check_res.unused:
err += f"Unused: {check_res.unused}\n"
errs.append(f"\tunused param: {check_res.unused}")

Exception.__init__(self, err)
Exception.__init__(self, '\n'.join(errs))

self.check_res = check_res
self.func_signature = func_signature

# Severity levels used by _check_loss_evaluate / _check_forward_error:
IGNORE_CHECK_LEVEL = 0   # say nothing about unused keys
WARNING_CHECK_LEVEL = 1  # warn about unused keys
STRICT_CHECK_LEVEL = 2   # treat unused keys as errors

def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:CheckRes,
                         output:dict, batch_y:dict, check_level=0):
    """Report problems found when passing model output and targets to a loss/metric.

    :param prev_func_signature: signature string of the function that produced
        ``output`` (usually the model's ``forward``).
    :param func_signature: signature string of the loss/metric being called.
    :param check_res: CheckRes with ``varargs``/``missing``/``duplicated``/``unused``.
    :param output: dict returned by the model forward.
    :param batch_y: dict of target fields from the dataset.
    :param check_level: IGNORE/WARNING/STRICT_CHECK_LEVEL; controls whether
        unused params raise (strict), warn (warning), or are ignored.
    :raises NameError: when varargs/missing/duplicated problems exist, or when
        unused params exist under strict checking.
    """
    errs = []
    _unused = []
    if check_res.varargs:
        errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, "
                    f"please delete it.)")
    if check_res.missing:
        errs.append(f"\tmissing param: {check_res.missing}, only provided with {list(output.keys())}"
                    f"(from {prev_func_signature}) and {list(batch_y.keys())}(from targets in Dataset).")
    if check_res.duplicated:
        # Name the producing function (prev_func_signature) as the place to
        # delete the duplicated output; the original message repeated the key
        # list here ("in the output of {check_res.duplicated}"), which made
        # the advice nonsensical.
        errs.append(f"\tduplicated param: {check_res.duplicated}, delete {check_res.duplicated} in the output of "
                    f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")
    if check_res.unused:
        _unused = [f"\tunused param: {check_res.unused}"]
        if check_level == STRICT_CHECK_LEVEL:
            errs.extend(_unused)

    if len(errs) > 0:
        errs.insert(0, f'The following problems occurred when calling {func_signature}')
        raise NameError('\n'.join(errs))
    if _unused:
        if check_level == WARNING_CHECK_LEVEL:
            _unused_warn = _unused[0] + f' in {func_signature}.'
            warnings.warn(message=_unused_warn)


def _check_forward_error(forward_func, batch_x, check_level):
    """Verify that ``batch_x`` supplies what ``forward_func`` expects.

    Missing params (and, under strict checking, unused ones) raise; under
    warning-level checking unused params only emit a warning.
    """
    check_res = _check_arg_dict_list(forward_func, batch_x)
    func_signature = get_func_signature(forward_func)

    problems = []
    unused_msgs = []

    if check_res.varargs:
        problems.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
    if check_res.missing:
        problems.append(f"\tmissing param: {check_res.missing}, only provided with {list(batch_x.keys())}.")
    if check_res.unused:
        unused_msgs = [f"\tunused param: {check_res.unused}"]
        if check_level == STRICT_CHECK_LEVEL:
            problems.extend(unused_msgs)

    if problems:
        header = f'The following problems occurred when calling {func_signature}'
        raise NameError('\n'.join([header] + problems))
    if unused_msgs and check_level == WARNING_CHECK_LEVEL:
        warnings.warn(message=unused_msgs[0] + f' in {func_signature}.')

Loading…
Cancel
Save