Browse Source

change the calculation of metric to batch by batch. The older design is to concat all data before calculation.

tags/v0.2.0^2
yh 6 years ago
parent
commit
f24fca1b21
7 changed files with 133 additions and 66 deletions
  1. +9
    -1
      fastNLP/core/batch.py
  2. +1
    -1
      fastNLP/core/fieldarray.py
  3. +18
    -3
      fastNLP/core/losses.py
  4. +78
    -17
      fastNLP/core/metrics.py
  5. +3
    -3
      fastNLP/core/tester.py
  6. +16
    -17
      fastNLP/core/trainer.py
  7. +8
    -24
      fastNLP/core/utils.py

+ 9
- 1
fastNLP/core/batch.py View File

@@ -1,4 +1,5 @@
import torch import torch
import numpy as np




class Batch(object): class Batch(object):
@@ -45,7 +46,7 @@ class Batch(object):
if field.is_target or field.is_input: if field.is_target or field.is_input:
batch = field.get(indices) batch = field.get(indices)
if not self.as_numpy: if not self.as_numpy:
batch = torch.from_numpy(batch)
batch = to_tensor(batch, field.dtype)
if field.is_target: if field.is_target:
batch_y[field_name] = batch batch_y[field_name] = batch
if field.is_input: if field.is_input:
@@ -54,3 +55,10 @@ class Batch(object):
self.curidx = endidx self.curidx = endidx


return batch_x, batch_y return batch_x, batch_y

def to_tensor(batch, dtype):
if dtype in (np.int8, np.int16, np.int32, np.int64):
batch = torch.LongTensor(batch)
if dtype in (np.float32, np.float64):
batch = torch.FloatTensor(batch)
return batch

+ 1
- 1
fastNLP/core/fieldarray.py View File

@@ -39,7 +39,7 @@ class FieldArray(object):


@staticmethod @staticmethod
def _map_to_np_type(basic_type): def _map_to_np_type(basic_type):
type_mapping = {int: np.int64, float: np.double, str: np.str}
type_mapping = {int: np.int64, float: np.float64, str: np.str}
return type_mapping[basic_type] return type_mapping[basic_type]


def __repr__(self): def __repr__(self):


+ 18
- 3
fastNLP/core/losses.py View File

@@ -126,15 +126,30 @@ class NLLLoss(LossBase):
class LossInForward(LossBase): class LossInForward(LossBase):
def __init__(self, loss_key='loss'): def __init__(self, loss_key='loss'):
super().__init__() super().__init__()
if not isinstance(loss_key, str):
raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.")
self.loss_key = loss_key self.loss_key = loss_key


def get_loss(self, **kwargs): def get_loss(self, **kwargs):
if self.loss_key not in kwargs: if self.loss_key not in kwargs:
pass
check_res = CheckRes(missing=[self.loss_key],
unused=[],
duplicated=[],
required=[],
all_needed=[],
varargs=[])
raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss))


def __call__(self, output_dict, predict_dict):
def __call__(self, output_dict, predict_dict, force_check=False):


return self.get_loss(**output_dict)
loss = self.get_loss(**output_dict)

if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
if not isinstance(loss, torch.Tensor):
raise TypeError(f"loss ERROR: loss except a torch.Tensor but got {type(loss)}")
raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size}")

return loss




def _prepare_losser(losser): def _prepare_losser(losser):


+ 78
- 17
fastNLP/core/metrics.py View File

@@ -10,7 +10,7 @@ from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _check_arg_dict_list from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _build_args from fastNLP.core.utils import _build_args
from fastNLP.core.utils import CheckError from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _check_function_or_method


class MetricBase(object): class MetricBase(object):
def __init__(self): def __init__(self):
@@ -20,19 +20,32 @@ class MetricBase(object):
def evaluate(self, *args, **kwargs): def evaluate(self, *args, **kwargs):
raise NotImplementedError raise NotImplementedError


def _init_param_map(self, key_map, **kwargs):
self.param_map = {}
value_counter = defaultdict(0)
for key, value in key_map.items():
if isinstance(key, str):
raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
if isinstance(value, str):
raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
self.param_map[key] = value
def _init_param_map(self, key_map=None, **kwargs):
value_counter = defaultdict(set)
if key_map is not None:
if not isinstance(key_map, dict):
raise TypeError("key_map must be `dict`, got {}.".format(type(key_map)))
for key, value in key_map.items():
if value is None:
self.param_map[key] = key
continue
if isinstance(key, str):
raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
if isinstance(value, str):
raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
self.param_map[key] = value
value_counter[value].add(key)
for key, value in kwargs.items(): for key, value in kwargs.items():
if value is None:
self.param_map[key] = key
continue
if isinstance(value, str): if isinstance(value, str):
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
self.param_map[key] = value self.param_map[key] = value
value_counter[value].add(key)
for value, key_set in value_counter.items():
if len(key_set)>1:
raise ValueError(f"Several params:{key_set} are provided with one output {value}.")


def __call__(self, output_dict, target_dict, check=False): def __call__(self, output_dict, target_dict, check=False):
""" """
@@ -45,8 +58,6 @@ class MetricBase(object):
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")


if not self._checked: if not self._checked:
# 0. check param_map does not have same value

# 1. check consistence between signature and param_map # 1. check consistence between signature and param_map
func_spect = inspect.getfullargspec(self.evaluate) func_spect = inspect.getfullargspec(self.evaluate)
func_args = func_spect.args func_args = func_spect.args
@@ -58,26 +69,32 @@ class MetricBase(object):
if arg not in self.param_map: if arg not in self.param_map:
self.param_map[arg] = arg #This param does not need mapping. self.param_map[arg] = arg #This param does not need mapping.
self._evaluate_args = func_args self._evaluate_args = func_args
self._reverse_param_map = {value: key for key, value in self.param_map.items()}


# need to wrap inputs in dict. # need to wrap inputs in dict.
mapped_output_dict = {} mapped_output_dict = {}
mapped_target_dict = {} mapped_target_dict = {}
for func_arg in self._evaluate_args: for func_arg in self._evaluate_args:
input_arg = self.param_map[func_arg] input_arg = self.param_map[func_arg]
if input_arg in self._reverse_param_map:
mapped_arg = func_arg
else:
mapped_arg = input_arg
if input_arg in output_dict: if input_arg in output_dict:
mapped_output_dict[func_arg] = output_dict[input_arg]
mapped_output_dict[mapped_arg] = output_dict[input_arg]
if input_arg in target_dict: if input_arg in target_dict:
mapped_target_dict[func_arg] = target_dict[input_arg]
mapped_target_dict[mapped_arg] = target_dict[input_arg]


# check duplicated, unused, missing # check duplicated, unused, missing
if check or not self._checked: if check or not self._checked:
check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_output_dict]) check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_output_dict])
self._reverse_param_map = {value:key for key, value in check_res.items()}
for key, value in check_res.items(): for key, value in check_res.items():
new_value = list(value) new_value = list(value)
for idx, func_param in enumerate(value): for idx, func_param in enumerate(value):
if func_param in self._reverse_param_map: if func_param in self._reverse_param_map:
new_value[idx] = self._reverse_param_map[func_param]
new_value[idx] = self._reverse_param_map[func_param] + f'(assign to {func_param})'
else:
new_value[idx] = func_param
if check_res.missing or check_res.duplicated or check_res.varargs: if check_res.missing or check_res.duplicated or check_res.varargs:
raise CheckError(check_res=check_res, raise CheckError(check_res=check_res,
func_signature=get_func_signature(self.evaluate)) func_signature=get_func_signature(self.evaluate))
@@ -93,11 +110,55 @@ class MetricBase(object):
return metrics return metrics




class Metric(MetricBase):
class FuncMetric(MetricBase):
def __init__(self, func, key_map, **kwargs): def __init__(self, func, key_map, **kwargs):
super().__init__() super().__init__()

_check_function_or_method(func=func)
self._init_param_map(key_map=key_map, **kwargs)

self.evaluate = func


class AccuracyMetric(MetricBase):
def __init__(self, predictions=None, targets=None, masks=None, seq_lens=None):
super().__init__()

self._init_param_map(predictions=predictions, targets=targets,
masks=masks, seq_lens=seq_lens)

def evaluate(self, predictions, targets, masks=None, seq_lens=None):
"""

:param predictions: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
torch.Size([]), torch.Size([n_classes,]), torch.Size([max_len,]), torch.Size([max_len, n_classes])
:param targets: List of (torch.Tensor, or numpy.ndarray). Element's can be:
torch.Size([]), torch.Size([]), torch.Size([max_len,]), torch.Size([max_len, ])
:param masks: List of (torch.Tensor, or numpy.ndarray). Element's can be:
None, None, torch.Size([max_len,], torch.Size([max_len, ])
:param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's can be:
None, None, torch.Size([1], torch.Size([1])
:return: dict({'acc': float})
"""
pass pass


def _check_evaluate_param(self, predictions, targets, masks=None, seq_lens=None):
# check the validity of self.evaluate param
prediction = predictions[0]
target = targets[0]

if len(np.shape(prediction))==len(target):
pass

if masks is not None:
mask = masks[0]
if seq_lens is not None:
seq_len = seq_lens[0]





def _prepare_metrics(metrics): def _prepare_metrics(metrics):
""" """




+ 3
- 3
fastNLP/core/tester.py View File

@@ -7,11 +7,11 @@ from torch import nn
from fastNLP.core.batch import Batch from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _build_args from fastNLP.core.utils import _build_args
from fastNLP.core.utils import get_func_signature from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _move_dict_value_to_device from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.metrics import _prepare_metrics from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _check_loss_evaluate from fastNLP.core.utils import _check_loss_evaluate


class Tester(object): class Tester(object):
@@ -57,7 +57,7 @@ class Tester(object):


with torch.no_grad(): with torch.no_grad():
for batch_x, batch_y in data_iterator: for batch_x, batch_y in data_iterator:
_move_dict_value_to_device(self._model_device, batch_x, batch_y)
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
prediction = self._data_forward(self._predict_func, batch_x) prediction = self._data_forward(self._predict_func, batch_x)
assert isinstance(prediction, dict) assert isinstance(prediction, dict)
for k, v in prediction.items(): for k, v in prediction.items():
@@ -77,7 +77,7 @@ class Tester(object):
except CheckError as e: except CheckError as e:
prev_func_signature = get_func_signature(self._predict_func) prev_func_signature = get_func_signature(self._predict_func)
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
check_res=e.check_res, output=output, batch_y=truths)
check_res=e.check_res, output=output, batch_y=truths, check_level=0)




if self.verbose >= 0: if self.verbose >= 0:


+ 16
- 17
fastNLP/core/trainer.py View File

@@ -1,6 +1,5 @@
import os import os
import time import time
import warnings
from datetime import datetime from datetime import datetime
from datetime import timedelta from datetime import timedelta


@@ -9,24 +8,19 @@ from tensorboardX import SummaryWriter
from torch import nn from torch import nn


from fastNLP.core.batch import Batch from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.optimizer import Adam from fastNLP.core.optimizer import Adam
from fastNLP.core.sampler import RandomSampler from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import SequentialSampler from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.tester import Tester from fastNLP.core.tester import Tester
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature
from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.core.losses import _prepare_losser from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.utils import CheckError from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _check_loss_evaluate from fastNLP.core.utils import _check_loss_evaluate
from fastNLP.core.utils import _check_forward_error from fastNLP.core.utils import _check_forward_error
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature


class Trainer(object): class Trainer(object):
"""Main Training Loop """Main Training Loop
@@ -52,6 +46,9 @@ class Trainer(object):
if metrics and (dev_data is None): if metrics and (dev_data is None):
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ") raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")


# check save_path
if not (save_path is None or isinstance(save_path, str)):
raise ValueError("save_path can only be None or `str`.")
# prepare evaluate # prepare evaluate
metrics = _prepare_metrics(metrics) metrics = _prepare_metrics(metrics)


@@ -156,7 +153,7 @@ class Trainer(object):
""" """
for batch_x, batch_y in data_iterator: for batch_x, batch_y in data_iterator:
# TODO 这里可能会遇到问题,万一用户在model内部修改了prediction的device就会有问题 # TODO 这里可能会遇到问题,万一用户在model内部修改了prediction的device就会有问题
_move_dict_value_to_device(self._model_device, batch_x, batch_y)
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
prediction = self._data_forward(model, batch_x) prediction = self._data_forward(model, batch_x)
loss = self._compute_loss(prediction, batch_y) loss = self._compute_loss(prediction, batch_y)
self._grad_backward(loss) self._grad_backward(loss)
@@ -232,11 +229,12 @@ class Trainer(object):
return self.losser(predict, truth) return self.losser(predict, truth)


def _save_model(self, model, model_name, only_param=False): def _save_model(self, model, model_name, only_param=False):
model_name = os.path.join(self.save_path, model_name)
if only_param:
torch.save(model.state_dict(), model_name)
else:
torch.save(model, model_name)
if self.save_path is not None:
model_name = os.path.join(self.save_path, model_name)
if only_param:
torch.save(model.state_dict(), model_name)
else:
torch.save(model, model_name)


def _better_eval_result(self, metrics): def _better_eval_result(self, metrics):
"""Check if the current epoch yields better validation results. """Check if the current epoch yields better validation results.
@@ -297,7 +295,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_


batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_count, (batch_x, batch_y) in enumerate(batch): for batch_count, (batch_x, batch_y) in enumerate(batch):
_move_dict_value_to_device(model_devcie, batch_x, batch_y)
_move_dict_value_to_device(batch_x, batch_y, device=model_devcie)
# forward check # forward check
if batch_count==0: if batch_count==0:
_check_forward_error(forward_func=model.forward, check_level=check_level, _check_forward_error(forward_func=model.forward, check_level=check_level,
@@ -335,6 +333,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
if dev_data is not None: if dev_data is not None:
tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
batch_size=batch_size, verbose=-1) batch_size=batch_size, verbose=-1)
tester.test()
evaluate_results = tester.test()
# TODO 这里需要检查是否返回来的值是否是合理的





+ 8
- 24
fastNLP/core/utils.py View File

@@ -122,13 +122,13 @@ def _check_arg_dict_list(func, args):
input_args = set(input_arg_count.keys()) input_args = set(input_arg_count.keys())
missing = list(require_args - input_args) missing = list(require_args - input_args)
unused = list(input_args - all_args) unused = list(input_args - all_args)
varargs = [] if spect.varargs else [arg for arg in spect.varargs]
return CheckRes(missing=missing, return CheckRes(missing=missing,
unused=unused, unused=unused,
duplicated=duplicated, duplicated=duplicated,
required=list(require_args), required=list(require_args),
all_needed=list(all_args), all_needed=list(all_args),
varargs=[arg for arg in spect.varargs])
varargs=varargs)


def get_func_signature(func): def get_func_signature(func):
""" """
@@ -165,6 +165,7 @@ def get_func_signature(func):
signature_str = func.__name__ + signature_str signature_str = func.__name__ + signature_str
return signature_str return signature_str



def _is_function_or_method(func): def _is_function_or_method(func):
""" """


@@ -179,26 +180,8 @@ def _check_function_or_method(func):
if not _is_function_or_method(func): if not _is_function_or_method(func):
raise TypeError(f"{type(func)} is not a method or function.") raise TypeError(f"{type(func)} is not a method or function.")


def _syn_model_data(model, *args):
"""

move data to model's device, element in *args should be dict. This is a inplace change.
:param model:
:param args:
:return:
"""
if len(model.state_dict())==0:
raise ValueError("model has no parameter.")
device = model.parameters().__next__().device
for arg in args:
if isinstance(arg, dict):
for key, value in arg.items():
if isinstance(value, torch.Tensor):
arg[key] = value.to(device)
else:
raise TypeError("Only support `dict` type right now.")


def _move_dict_value_to_device(device, *args):
def _move_dict_value_to_device(*args, device:torch.device):
""" """


move data to model's device, element in *args should be dict. This is a inplace change. move data to model's device, element in *args should be dict. This is a inplace change.
@@ -240,6 +223,7 @@ class CheckError(Exception):
self.check_res = check_res self.check_res = check_res
self.func_signature = func_signature self.func_signature = func_signature



IGNORE_CHECK_LEVEL = 0 IGNORE_CHECK_LEVEL = 0
WARNING_CHECK_LEVEL = 1 WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2 STRICT_CHECK_LEVEL = 2
@@ -252,8 +236,8 @@ def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, " errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, "
f"please delete it.)") f"please delete it.)")
if check_res.missing: if check_res.missing:
errs.append(f"\tmissing param: {check_res.missing}, only provided with {list(output.keys())}"
f"(from {prev_func_signature}) and {list(batch_y.keys())}(from targets in Dataset).")
errs.append(f"\tmissing param: `{check_res.missing}`, provided with `{list(output.keys())}`"
f"(from output of `{prev_func_signature}`) and `{list(batch_y.keys())}`(from targets in Dataset).")
if check_res.duplicated: if check_res.duplicated:
errs.append(f"\tduplicated param: {check_res.duplicated}, delete {check_res.duplicated} in the output of " errs.append(f"\tduplicated param: {check_res.duplicated}, delete {check_res.duplicated} in the output of "
f"{check_res.duplicated} or do not set {check_res.duplicated} as targets. ") f"{check_res.duplicated} or do not set {check_res.duplicated} as targets. ")
@@ -281,7 +265,7 @@ def _check_forward_error(forward_func, batch_x, check_level):
if check_res.varargs: if check_res.varargs:
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
if check_res.missing: if check_res.missing:
errs.append(f"\tmissing param: {check_res.missing}, only provided with {list(batch_x.keys())}.")
errs.append(f"\tmissing param: {check_res.missing}, provided with {list(batch_x.keys())}.")
if check_res.unused: if check_res.unused:
_unused = [f"\tunused param: {check_res.unused}"] _unused = [f"\tunused param: {check_res.unused}"]
if check_level == STRICT_CHECK_LEVEL: if check_level == STRICT_CHECK_LEVEL:


Loading…
Cancel
Save