Browse Source

Change metric calculation to run batch by batch. The previous design concatenated all data before computing the metric.

tags/v0.2.0^2
yh 6 years ago
parent
commit
f24fca1b21
7 changed files with 133 additions and 66 deletions
  1. +9
    -1
      fastNLP/core/batch.py
  2. +1
    -1
      fastNLP/core/fieldarray.py
  3. +18
    -3
      fastNLP/core/losses.py
  4. +78
    -17
      fastNLP/core/metrics.py
  5. +3
    -3
      fastNLP/core/tester.py
  6. +16
    -17
      fastNLP/core/trainer.py
  7. +8
    -24
      fastNLP/core/utils.py

+ 9
- 1
fastNLP/core/batch.py View File

@@ -1,4 +1,5 @@
import torch
import numpy as np


class Batch(object):
@@ -45,7 +46,7 @@ class Batch(object):
if field.is_target or field.is_input:
batch = field.get(indices)
if not self.as_numpy:
batch = torch.from_numpy(batch)
batch = to_tensor(batch, field.dtype)
if field.is_target:
batch_y[field_name] = batch
if field.is_input:
@@ -54,3 +55,10 @@ class Batch(object):
self.curidx = endidx

return batch_x, batch_y

def to_tensor(batch, dtype):
    """Convert a batch to a torch tensor chosen by the field's dtype.

    :param batch: the batch data to convert (array-like accepted by the
        ``torch.LongTensor`` / ``torch.FloatTensor`` constructors).
    :param dtype: the dtype recorded for the field the batch came from.
    :return: a ``torch.LongTensor`` for signed-integer dtypes, a
        ``torch.FloatTensor`` for ``float32``/``float64``; for any other dtype
        the batch is returned unchanged.
    """
    # The two dtype groups are mutually exclusive, so the second test is an
    # ``elif``: once a batch has been converted there is nothing left to check.
    if dtype in (np.int8, np.int16, np.int32, np.int64):
        batch = torch.LongTensor(batch)
    elif dtype in (np.float32, np.float64):
        batch = torch.FloatTensor(batch)
    return batch

+ 1
- 1
fastNLP/core/fieldarray.py View File

@@ -39,7 +39,7 @@ class FieldArray(object):

@staticmethod
def _map_to_np_type(basic_type):
    """Map a built-in Python type to the dtype used to store it.

    :param basic_type: one of ``int``, ``float`` or ``str``.
    :return: ``np.int64`` for ``int``, ``np.float64`` for ``float`` and
        ``str`` for ``str``.
    :raises KeyError: if ``basic_type`` is not one of the supported types.
    """
    # ``np.str`` was only an alias of the builtin ``str`` and no longer exists
    # in NumPy >= 1.24; using ``str`` directly keeps the mapping identical.
    type_mapping = {int: np.int64, float: np.float64, str: str}
    return type_mapping[basic_type]

def __repr__(self):


+ 18
- 3
fastNLP/core/losses.py View File

@@ -126,15 +126,30 @@ class NLLLoss(LossBase):
class LossInForward(LossBase):
    """Loss that reads an already-computed loss out of the model's output dict.

    The loss is looked up under ``loss_key`` and must be a 0-dimensional
    ``torch.Tensor``.
    """

    def __init__(self, loss_key='loss'):
        """
        :param loss_key: name of the entry in the output dict that holds the loss.
        :raises TypeError: if ``loss_key`` is not a ``str``.
        """
        super().__init__()
        if not isinstance(loss_key, str):
            raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.")
        self.loss_key = loss_key

    def get_loss(self, **kwargs):
        """Return the value stored under ``self.loss_key``.

        :param kwargs: the output dict, unpacked as keyword arguments.
        :return: ``kwargs[self.loss_key]``.
        :raises CheckError: if ``self.loss_key`` is missing from ``kwargs``.
        """
        if self.loss_key not in kwargs:
            check_res = CheckRes(missing=[self.loss_key],
                                 unused=[],
                                 duplicated=[],
                                 required=[],
                                 all_needed=[],
                                 varargs=[])
            raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss))
        # Without this return the method yields None and __call__ can never succeed.
        return kwargs[self.loss_key]

    def __call__(self, output_dict, predict_dict, force_check=False):
        """Fetch the loss from ``output_dict`` and validate it.

        :param output_dict: dict that must contain ``self.loss_key``.
        :param predict_dict: not used by this implementation.
        :param force_check: not used by this implementation.
        :return: the loss, a 0-dimensional ``torch.Tensor``.
        :raises TypeError: if the loss is not a ``torch.Tensor``.
        :raises RuntimeError: if the loss tensor is not 0-dimensional.
        """
        loss = self.get_loss(**output_dict)

        if not isinstance(loss, torch.Tensor):
            raise TypeError(f"loss ERROR: loss except a torch.Tensor but got {type(loss)}")
        if len(loss.size()) != 0:
            # Call size() so the message shows the actual shape, not a bound method repr.
            raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size()}")

        return loss


def _prepare_losser(losser):


+ 78
- 17
fastNLP/core/metrics.py View File

@@ -10,7 +10,7 @@ from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _check_function_or_method

class MetricBase(object):
def __init__(self):
@@ -20,19 +20,32 @@ class MetricBase(object):
def evaluate(self, *args, **kwargs):
raise NotImplementedError

def _init_param_map(self, key_map=None, **kwargs):
    """Fill ``self.param_map`` with ``evaluate``-parameter -> input-name pairs.

    Mappings can be given as a dict (``key_map``), as keyword arguments, or
    both. A value of ``None`` means the parameter maps to an input of the
    same name.

    :param key_map: optional dict of parameter name -> input name.
    :param kwargs: further parameter name -> input name pairs.
    :raises TypeError: if ``key_map`` is not a dict, or a key/value is not ``str``.
    :raises ValueError: if several parameters are mapped to the same input name.
    """
    # input name -> set of parameters mapped to it, used to detect collisions
    value_counter = defaultdict(set)
    if key_map is not None:
        if not isinstance(key_map, dict):
            raise TypeError("key_map must be `dict`, got {}.".format(type(key_map)))
        for key, value in key_map.items():
            if value is None:
                self.param_map[key] = key
                continue
            # Reject only NON-str entries; the check used to be inverted.
            if not isinstance(key, str):
                raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
            if not isinstance(value, str):
                raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
            self.param_map[key] = value
            value_counter[value].add(key)
    for key, value in kwargs.items():
        if value is None:
            self.param_map[key] = key
            continue
        if not isinstance(value, str):
            raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
        self.param_map[key] = value
        value_counter[value].add(key)
    for value, key_set in value_counter.items():
        if len(key_set) > 1:
            raise ValueError(f"Several params:{key_set} are provided with one output {value}.")

def __call__(self, output_dict, target_dict, check=False):
"""
@@ -45,8 +58,6 @@ class MetricBase(object):
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

if not self._checked:
# 0. check param_map does not have same value

# 1. check consistence between signature and param_map
func_spect = inspect.getfullargspec(self.evaluate)
func_args = func_spect.args
@@ -58,26 +69,32 @@ class MetricBase(object):
if arg not in self.param_map:
self.param_map[arg] = arg #This param does not need mapping.
self._evaluate_args = func_args
self._reverse_param_map = {value: key for key, value in self.param_map.items()}

# need to wrap inputs in dict.
mapped_output_dict = {}
mapped_target_dict = {}
for func_arg in self._evaluate_args:
input_arg = self.param_map[func_arg]
if input_arg in self._reverse_param_map:
mapped_arg = func_arg
else:
mapped_arg = input_arg
if input_arg in output_dict:
mapped_output_dict[func_arg] = output_dict[input_arg]
mapped_output_dict[mapped_arg] = output_dict[input_arg]
if input_arg in target_dict:
mapped_target_dict[func_arg] = target_dict[input_arg]
mapped_target_dict[mapped_arg] = target_dict[input_arg]

# check duplicated, unused, missing
if check or not self._checked:
check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_output_dict])
self._reverse_param_map = {value:key for key, value in check_res.items()}
for key, value in check_res.items():
new_value = list(value)
for idx, func_param in enumerate(value):
if func_param in self._reverse_param_map:
new_value[idx] = self._reverse_param_map[func_param]
new_value[idx] = self._reverse_param_map[func_param] + f'(assign to {func_param})'
else:
new_value[idx] = func_param
if check_res.missing or check_res.duplicated or check_res.varargs:
raise CheckError(check_res=check_res,
func_signature=get_func_signature(self.evaluate))
@@ -93,11 +110,55 @@ class MetricBase(object):
return metrics


class FuncMetric(MetricBase):
    """Metric whose ``evaluate`` is a user-supplied function or method."""

    def __init__(self, func, key_map=None, **kwargs):
        """
        :param func: the function or method to use as ``evaluate``; validated
            with ``_check_function_or_method``.
        :param key_map: optional dict forwarded to ``_init_param_map``.
            Defaults to ``None`` so mappings may be given through keyword
            arguments alone.
        :param kwargs: further mappings forwarded to ``_init_param_map``.
        """
        super().__init__()

        _check_function_or_method(func=func)
        self._init_param_map(key_map=key_map, **kwargs)

        self.evaluate = func


class AccuracyMetric(MetricBase):
    """Accuracy metric. NOTE(review): ``evaluate`` is still an unimplemented stub."""

    def __init__(self, predictions=None, targets=None, masks=None, seq_lens=None):
        """Register the input names for each ``evaluate`` parameter.

        Each argument is forwarded to ``_init_param_map`` under its own name.
        """
        super().__init__()

        self._init_param_map(predictions=predictions, targets=targets,
                             masks=masks, seq_lens=seq_lens)

    def evaluate(self, predictions, targets, masks=None, seq_lens=None):
        """

        :param predictions: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
            torch.Size([]), torch.Size([n_classes,]), torch.Size([max_len,]), torch.Size([max_len, n_classes])
        :param targets: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
            torch.Size([]), torch.Size([]), torch.Size([max_len,]), torch.Size([max_len,])
        :param masks: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
            None, None, torch.Size([max_len,]), torch.Size([max_len,])
        :param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
            None, None, torch.Size([1]), torch.Size([1])
        :return: dict({'acc': float})
        """
        # TODO: not implemented yet; currently returns None.
        pass

    def _check_evaluate_param(self, predictions, targets, masks=None, seq_lens=None):
        """Inspect the first element of each argument passed to ``evaluate``.

        TODO: the individual checks are placeholders and enforce nothing yet.
        """
        # check the validity of self.evaluate param
        prediction = predictions[0]
        target = targets[0]

        # Compare ranks on both sides. ``len(target)`` would fail for the
        # 0-dimensional targets that ``evaluate`` documents as valid input.
        if len(np.shape(prediction)) == len(np.shape(target)):
            pass

        if masks is not None:
            mask = masks[0]  # placeholder: not validated yet
        if seq_lens is not None:
            seq_len = seq_lens[0]  # placeholder: not validated yet





def _prepare_metrics(metrics):
"""



+ 3
- 3
fastNLP/core/tester.py View File

@@ -7,11 +7,11 @@ from torch import nn
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.dataset import DataSet
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _check_loss_evaluate

class Tester(object):
@@ -57,7 +57,7 @@ class Tester(object):

with torch.no_grad():
for batch_x, batch_y in data_iterator:
_move_dict_value_to_device(self._model_device, batch_x, batch_y)
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
prediction = self._data_forward(self._predict_func, batch_x)
assert isinstance(prediction, dict)
for k, v in prediction.items():
@@ -77,7 +77,7 @@ class Tester(object):
except CheckError as e:
prev_func_signature = get_func_signature(self._predict_func)
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
check_res=e.check_res, output=output, batch_y=truths)
check_res=e.check_res, output=output, batch_y=truths, check_level=0)


if self.verbose >= 0:


+ 16
- 17
fastNLP/core/trainer.py View File

@@ -1,6 +1,5 @@
import os
import time
import warnings
from datetime import datetime
from datetime import timedelta

@@ -9,24 +8,19 @@ from tensorboardX import SummaryWriter
from torch import nn

from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.optimizer import Adam
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.tester import Tester
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature
from fastNLP.core.dataset import DataSet
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _check_loss_evaluate
from fastNLP.core.utils import _check_forward_error
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature

class Trainer(object):
"""Main Training Loop
@@ -52,6 +46,9 @@ class Trainer(object):
if metrics and (dev_data is None):
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")

# check save_path
if not (save_path is None or isinstance(save_path, str)):
raise ValueError("save_path can only be None or `str`.")
# prepare evaluate
metrics = _prepare_metrics(metrics)

@@ -156,7 +153,7 @@ class Trainer(object):
"""
for batch_x, batch_y in data_iterator:
# TODO: this may break if the user changes the device of the prediction inside the model
_move_dict_value_to_device(self._model_device, batch_x, batch_y)
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
prediction = self._data_forward(model, batch_x)
loss = self._compute_loss(prediction, batch_y)
self._grad_backward(loss)
@@ -232,11 +229,12 @@ class Trainer(object):
return self.losser(predict, truth)

def _save_model(self, model, model_name, only_param=False):
    """Save the model under ``self.save_path``; do nothing if it is ``None``.

    :param model: the model to save.
    :param model_name: file name, joined onto ``self.save_path``.
    :param only_param: if True save only ``model.state_dict()``, otherwise
        save the whole model object.
    """
    # Guard first: joining a ``None`` save_path would raise a TypeError.
    if self.save_path is None:
        return
    model_path = os.path.join(self.save_path, model_name)
    if only_param:
        torch.save(model.state_dict(), model_path)
    else:
        torch.save(model, model_path)

def _better_eval_result(self, metrics):
"""Check if the current epoch yields better validation results.
@@ -297,7 +295,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_

batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_count, (batch_x, batch_y) in enumerate(batch):
_move_dict_value_to_device(model_devcie, batch_x, batch_y)
_move_dict_value_to_device(batch_x, batch_y, device=model_devcie)
# forward check
if batch_count==0:
_check_forward_error(forward_func=model.forward, check_level=check_level,
@@ -335,6 +333,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
if dev_data is not None:
tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
batch_size=batch_size, verbose=-1)
tester.test()
evaluate_results = tester.test()
# TODO: check whether the returned values are reasonable



+ 8
- 24
fastNLP/core/utils.py View File

@@ -122,13 +122,13 @@ def _check_arg_dict_list(func, args):
input_args = set(input_arg_count.keys())
missing = list(require_args - input_args)
unused = list(input_args - all_args)
varargs = [] if spect.varargs else [arg for arg in spect.varargs]
return CheckRes(missing=missing,
unused=unused,
duplicated=duplicated,
required=list(require_args),
all_needed=list(all_args),
varargs=[arg for arg in spect.varargs])
varargs=varargs)

def get_func_signature(func):
"""
@@ -165,6 +165,7 @@ def get_func_signature(func):
signature_str = func.__name__ + signature_str
return signature_str


def _is_function_or_method(func):
"""

@@ -179,26 +180,8 @@ def _check_function_or_method(func):
if not _is_function_or_method(func):
raise TypeError(f"{type(func)} is not a method or function.")

def _syn_model_data(model, *args):
    """Move tensor values of the given dicts to the model's device, in place.

    The device is taken from the model's first parameter. Non-tensor values
    are left untouched.

    :param model: the model whose device the data should follow.
    :param args: dicts whose ``torch.Tensor`` values are replaced by copies on
        the model's device.
    :raises ValueError: if the model has no parameter.
    :raises TypeError: if an element of ``args`` is not a dict.
    """
    # ``next`` with a default also covers modules that have buffers but no
    # parameters, which used to escape as a bare StopIteration.
    first_param = next(model.parameters(), None)
    if first_param is None:
        raise ValueError("model has no parameter.")
    device = first_param.device
    for arg in args:
        if not isinstance(arg, dict):
            raise TypeError("Only support `dict` type right now.")
        for key, value in arg.items():
            if isinstance(value, torch.Tensor):
                arg[key] = value.to(device)

def _move_dict_value_to_device(device, *args):
def _move_dict_value_to_device(*args, device:torch.device):
"""

move data to model's device, element in *args should be dict. This is a inplace change.
@@ -240,6 +223,7 @@ class CheckError(Exception):
self.check_res = check_res
self.func_signature = func_signature


IGNORE_CHECK_LEVEL = 0
WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2
@@ -252,8 +236,8 @@ def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, "
f"please delete it.)")
if check_res.missing:
errs.append(f"\tmissing param: {check_res.missing}, only provided with {list(output.keys())}"
f"(from {prev_func_signature}) and {list(batch_y.keys())}(from targets in Dataset).")
errs.append(f"\tmissing param: `{check_res.missing}`, provided with `{list(output.keys())}`"
f"(from output of `{prev_func_signature}`) and `{list(batch_y.keys())}`(from targets in Dataset).")
if check_res.duplicated:
errs.append(f"\tduplicated param: {check_res.duplicated}, delete {check_res.duplicated} in the output of "
f"{check_res.duplicated} or do not set {check_res.duplicated} as targets. ")
@@ -281,7 +265,7 @@ def _check_forward_error(forward_func, batch_x, check_level):
if check_res.varargs:
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
if check_res.missing:
errs.append(f"\tmissing param: {check_res.missing}, only provided with {list(batch_x.keys())}.")
errs.append(f"\tmissing param: {check_res.missing}, provided with {list(batch_x.keys())}.")
if check_res.unused:
_unused = [f"\tunused param: {check_res.unused}"]
if check_level == STRICT_CHECK_LEVEL:


Loading…
Cancel
Save