Browse Source

增加metric

tags/v0.2.0^2
yh 6 years ago
parent
commit
ad0a8c1775
5 changed files with 245 additions and 87 deletions
  1. +23
    -0
      fastNLP/core/losses.py
  2. +128
    -1
      fastNLP/core/metrics.py
  3. +41
    -15
      fastNLP/core/tester.py
  4. +30
    -41
      fastNLP/core/trainer.py
  5. +23
    -30
      fastNLP/core/utils.py

+ 23
- 0
fastNLP/core/losses.py View File

@@ -17,6 +17,29 @@ class Loss(LossBase):
pass


class LossInForward(LossBase):
def __init__(self, loss_key='loss'):
super().__init__()

self.loss_key = loss_key

def get_loss(self, *args, **kwargs):
pass

def __call__(self, output_dict, predict_dict):
pass


def _prepare_losser(losser):
if losser is None:
losser = LossInForward()
return losser
elif isinstance(losser, LossBase):
return losser
else:
raise TypeError(f"Type of losser should be `fastNLP.LossBase`, got {type(losser)}")


def squash(predict, truth, **kwargs):
'''To reshape tensors in order to fit Loss functions in pytorch



+ 128
- 1
fastNLP/core/metrics.py View File

@@ -1,8 +1,136 @@

import warnings
import inspect

import numpy as np
import torch

from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _build_args

class MetricBase(object):
def __init__(self):
self.param_map = {} # key is param in function, value is input param.
self._checked = False

def evaluate(self, *args, **kwargs):
raise NotImplementedError

def _init_param_map(self, key_map, **kwargs):
self.param_map = {}
for key, value in key_map.items():
if isinstance(key, str):
raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
if isinstance(value, str):
raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
self.param_map[key] = value
for key, value in kwargs.items():
if isinstance(value, str):
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
self.param_map[key] = value
def __call__(self, output_dict, target_dict, force_check=False):
"""
:param output_dict:
:param target_dict:
:return:
"""
if not callable(self.evaluate):
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

if not self._checked:
# 1. check consistence between signature and param_map
func_spect = inspect.getfullargspec(self.evaluate)
func_args = func_spect.args
for func_param, input_param in self.param_map.items():
if func_param not in func_args:
raise NameError(f"{func_param} not in {get_func_signature(self.evaluate)}.")
# 2. only part of the param_map are passed, left are not
for arg in func_args:
if arg not in self.param_map:
self.param_map[arg] = arg #This param does not need mapping.
self._evaluate_args = func_args

# need to wrap inputs in dict.
mapped_output_dict = {}
mapped_target_dict = {}
for func_arg in self._evaluate_args:
input_arg = self.param_map[func_arg]
if input_arg in output_dict:
mapped_output_dict[func_arg] = output_dict[input_arg]
if input_arg in target_dict:
mapped_target_dict[func_arg] = target_dict[input_arg]

# check duplicated, unused, missing
if force_check or not self._checked:
check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_output_dict])
self._reverse_param_map = {value:key for key, value in check_res.items()}
for key, value in check_res.items():
new_value = value.copy()
for idx, func_param in enumerate(value):
if func_param in self._reverse_param_map:
new_value[idx] = self._reverse_param_map[func_param]
if check_res.missing or check_res.duplicated:
raise CheckError(check_res=check_res)
refined_args = _build_args(self.evaluate, **mapped_output_dict, **mapped_target_dict)

metrics = self.evaluate(**refined_args)

if not isinstance(metrics, dict):
raise TypeError(f"The return value of {get_func_signature(self.evaluate)} must be `dict`, "
f"got {type(metrics)}.")
self._checked = True

return metrics





class CheckError(Exception):
def __init__(self, check_res):

err = ''
if check_res.missing:
err += f'Missing: {check_res.missing}\n'
if check_res.duplicated:
err += f'Duplicated: {check_res.duplicated}\n'
self.check_res = check_res

def __str__(self):
pass


class Metric(MetricBase):
def __init__(self, func, key_map, **kwargs):
super().__init__()
pass

def _prepare_metrics(metrics):
"""

Prepare list of Metric based on input
:param metrics:
:return: List[fastNLP.MetricBase]
"""
_metrics = []
if metrics:
if isinstance(metrics, list):
for metric in metrics:
if isinstance(metric, type):
metric = metric()
if isinstance(metric, MetricBase):
_metrics.append(metric)
else:
raise TypeError(f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.")
elif isinstance(metrics, MetricBase):
_metrics = [metrics]
else:
raise TypeError("The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, got {}."
.format(type(metrics)))
return _metrics


class Evaluator(object):
def __init__(self):
@@ -17,7 +145,6 @@ class Evaluator(object):
"""
raise NotImplementedError


class ClassifyEvaluator(Evaluator):
def __init__(self):
super(ClassifyEvaluator, self).__init__()


+ 41
- 15
fastNLP/core/tester.py View File

@@ -2,32 +2,49 @@ import itertools
from collections import defaultdict

import torch
from torch import nn

from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.dataset import DataSet
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.metrics import _prepare_metrics

class Tester(object):
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. """

def __init__(self, data, model, metrics, batch_size=16, use_cuda=False, verbose=0):
super(Tester, self).__init__()
self.use_cuda = use_cuda

if not isinstance(data, DataSet):
raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.")
if not isinstance(model, nn.Module):
raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.")

self.metrics = _prepare_metrics(metrics)

# check predict
if hasattr(self._model, 'predict'):
self._predict_func = self._model.predict
if not callable(self._predict_func):
_model_name = model.__class__.__name__
raise TypeError(f"`{_model_name}.predict` must be callable to be used "
f"for evaluation, not `{type(self._predict_func)}`.")
else:
self._predict_func = self._model

self.data = data
self.batch_size = batch_size
self.verbose = verbose
if torch.cuda.is_available() and self.use_cuda:
self._model = model.cuda()
else:
self._model = model
if hasattr(self._model, 'predict'):
if not callable(self._model.predict):
raise TypeError(f"{get_func_signature(model.predict)} must be callable to be used "
f"for evaluation.")
self._predict_func = self._model.predict
else:
self._predict_func = self._model
self.use_cuda = use_cuda
self.batch_size = batch_size
self.verbose = verbose

self._model_device = model.parameters().__next__().device


def test(self):
@@ -39,6 +56,7 @@ class Tester(object):

with torch.no_grad():
for batch_x, batch_y in data_iterator:
_move_dict_value_to_device(self._model_device, batch_x, batch_y)
prediction = self.data_forward(network, batch_x)
assert isinstance(prediction, dict)
for k, v in prediction.items():
@@ -49,10 +67,13 @@ class Tester(object):
output[k] = itertools.chain(*v)
for k, v in truths.items():
truths[k] = itertools.chain(*v)
# args = _build_args(self._evaluator, **output, **truths)
eval_results = self._evaluator(**args)
eval_results = {}
for metric in self.metrics:
eval_result = metric(output, truths)
metric_name = metric.__class__.__name__
eval_results[metric_name] = eval_result
if self.verbose >= 0:
print("[tester] {}".format(self.print_eval_results(eval_results)))
print("[tester] \n{}".format(self.format_eval_results(eval_results)))
self.mode(network, is_test=False)
return eval_results

@@ -74,10 +95,15 @@ class Tester(object):
y = self._predict_func(**x)
return y

def print_eval_results(self, results):
def format_eval_results(self, results):
"""Override this method to support more print formats.

:param results: dict, (str: float) is (metrics name: value)

"""
return ", ".join([str(key) + "=" + str(value) for key, value in results.items()])
_str = ''
for metric_name, metric_result in results.items():
_str += metric_name + '\n\t'
_str += ", ".join([str(key) + "=" + str(value) for key, value in results.items()])
_str += '\n'
return _str

+ 30
- 41
fastNLP/core/trainer.py View File

@@ -17,10 +17,15 @@ from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.tester import Tester
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _syn_model_data
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature
from fastNLP.core.dataset import DataSet

from fastNLP.core.losses import LossBase
from fastNLP.core.metrics import MetricBase
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics


class Trainer(object):
"""Main Training Loop
@@ -32,6 +37,25 @@ class Trainer(object):
**kwargs):
super(Trainer, self).__init__()

if not isinstance(train_data, DataSet):
raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.")
if not isinstance(model, nn.Module):
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")

# check metrics and dev_data
if (not metrics) and dev_data is not None:
raise ValueError("No metric for dev_data evaluation.")
if metrics and (dev_data is None):
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")

# prepare evaluate
metrics = _prepare_metrics(metrics)
# prepare loss
losser = _prepare_losser(losser)

if need_check_code:
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data)

self.train_data = train_data
self.dev_data = dev_data # If None, No validation.
self.model = model
@@ -45,10 +69,7 @@ class Trainer(object):
self.validate_every = int(validate_every)
self._best_accuracy = 0


# TODO check loss与metrics的类型


self._model_device = model.parameters().__next__().device

# TODO self._best_accuracy不能表现出当前的metric多种的情况

@@ -72,38 +93,6 @@ class Trainer(object):

# print(self.__dict__)

def _check_params(self, train_data, model, losser, metrics=[], n_epochs=3, batch_size=32, print_every=-1,
validate_every=-1, dev_data=None, use_cuda=False, save_path="./save",
optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
**kwargs):
if not isinstance(train_data, DataSet):
raise TypeError("The type of train_data must be fastNLP.DataSet, got {}.".\
format(type(train_data)))
if not isinstance(model, nn.Module):
raise TypeError("The type of model must be torch.nn.Module, got {}.".\
format(type(model)))
if losser is not None:
# TODO change
if not isinstance(losser, None):
raise TypeError("The type of losser must be xxx, got {}.".\
format(type(losser)))

# check metrics and dev_data
if (not metrics) and dev_data is not None:
raise ValueError("No metric for dev_data evaluation.")
if metrics and (dev_data is None):
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")

# check loss
if isinstance(losser, type):
self.losser = losser()
if not isinstance(self.losser, None):
raise TypeError(f'The type of losser must be `{}`, got {type(self.losser)}.')

if need_check_code:
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data)


def train(self):
"""Start Training.

@@ -153,8 +142,9 @@ class Trainer(object):
- epoch: int,
"""
for batch_x, batch_y in data_iterator:
# TODO 这里可能会遇到问题,万一用户在model内部修改了prediction的device就会有问题
_move_dict_value_to_device(self._model_device, batch_x, batch_y)
prediction = self.data_forward(model, batch_x)

loss = self.get_loss(prediction, batch_y)
self.grad_backward(loss)
self.update()
@@ -205,7 +195,6 @@ class Trainer(object):
x = _build_args(network.forward, **x)
y = network(**x)
if not isinstance(y, dict):

raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.")
return y

@@ -299,7 +288,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
# check loss output
if batch_count == 0:
if not isinstance(loss, torch.Tensor):
raise ValueError("The return value of {} should be torch.Tensor, but got {}.".
raise ValueError("The return value of {} should be `torch.Tensor`, but got `{}`.".
format(type(losser), type(loss)))
if len(loss.size())!=0:
raise ValueError("The size of return value of {} is {}, should be torch.size([])".format(
@@ -314,7 +303,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
outputs, truths = defaultdict(list), defaultdict(list)
dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
# TODO 这里修改为使用tester
tester = Tester(data=dataset, model=model, metrics=metrics, batch_size=batch_size, )

with torch.no_grad():
for batch_count, (batch_x, batch_y) in enumerate(dev_batch):


+ 23
- 30
fastNLP/core/utils.py View File

@@ -3,11 +3,9 @@ import inspect
import os
from collections import Counter
from collections import namedtuple
from collections import defaultdict
import torch

CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False)


def save_pickle(obj, pickle_path, file_name):
"""Save an object into a pickle file.

@@ -89,11 +87,15 @@ def _check_arg_dict_list(func, args):
input_args = set(input_arg_count.keys())
missing = list(require_args - input_args)
unused = list(input_args - all_args)
return CheckRes(missing=missing,
unused=unused,
duplicated=duplicated,
required=list(require_args),
all_needed=list(all_args))

check_res = {}
check_res['missing'] = missing
check_res['unused'] = unused
check_res['duplicated'] = duplicated
check_res['required'] = list(require_args)
check_res['all_needed'] = list(all_args)

return check_res

def get_func_signature(func):
"""
@@ -150,31 +152,22 @@ def _syn_model_data(model, *args):
else:
raise TypeError("Only support `dict` type right now.")

def _prepare_metrics(metrics):
def _move_dict_value_to_device(device, *args):
"""

Prepare list of Metric based on input
:param metrics:
move data to model's device, element in *args should be dict. This is a inplace change.
:param device: torch.device
:param args:
:return:
"""
_metrics = []
if metrics:
if isinstance(metrics, list):
for metric in metrics:
if isinstance(metric, type):
metric = metric()
if isinstance(metric, None):
_metrics.append(metric)
else:
raise TypeError("The type of metric in metrics must be xxxx, not {}.".format(
type(), type(metric)
))
elif isinstance(metrics, None):
_metrics = [metrics]
else:
raise TypeError("The type of metrics should be `list[xxx]` or `xxx`, got {}.".format(
type(metrics)
))
if not isinstance(device, torch.device):
raise TypeError(f"device must be `torch.device`, got `{type(device)}`")

return _metrics
for arg in args:
if isinstance(arg, dict):
for key, value in arg.items():
if isinstance(value, torch.Tensor):
arg[key] = value.to(device)
else:
raise TypeError("Only support `dict` type right now.")


Loading…
Cancel
Save