
Add metric

tags/v0.2.0^2
yh · 6 years ago · commit ad0a8c1775
5 changed files with 245 additions and 87 deletions

  1. fastNLP/core/losses.py    +23   -0
  2. fastNLP/core/metrics.py   +128  -1
  3. fastNLP/core/tester.py    +41   -15
  4. fastNLP/core/trainer.py   +30   -41
  5. fastNLP/core/utils.py     +23   -30

fastNLP/core/losses.py  (+23, -0)

@@ -17,6 +17,29 @@ class Loss(LossBase):
    pass




class LossInForward(LossBase):
    def __init__(self, loss_key='loss'):
        super().__init__()
        self.loss_key = loss_key

    def get_loss(self, *args, **kwargs):
        pass

    def __call__(self, output_dict, predict_dict):
        pass


def _prepare_losser(losser):
    if losser is None:
        losser = LossInForward()
        return losser
    elif isinstance(losser, LossBase):
        return losser
    else:
        raise TypeError(f"Type of losser should be `fastNLP.LossBase`, got {type(losser)}")


def squash(predict, truth, **kwargs):
    '''To reshape tensors in order to fit Loss functions in pytorch



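A minimal usage sketch of the new _prepare_losser helper, assuming the LossBase and LossInForward definitions from this file (the still-stubbed methods are not exercised):

# Sketch only: relies on LossBase / LossInForward / _prepare_losser defined above.
losser = _prepare_losser(None)               # no loss given -> fall back to LossInForward
assert isinstance(losser, LossInForward)

custom = LossInForward(loss_key='my_loss')   # 'my_loss' is an illustrative key
assert _prepare_losser(custom) is custom     # LossBase instances are passed through unchanged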

fastNLP/core/metrics.py  (+128, -1)

@@ -1,8 +1,136 @@

import warnings
import inspect

import numpy as np
import torch


from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _build_args

class MetricBase(object):
    def __init__(self):
        self.param_map = {}  # key is param in function, value is input param.
        self._checked = False

    def evaluate(self, *args, **kwargs):
        raise NotImplementedError

    def _init_param_map(self, key_map, **kwargs):
        self.param_map = {}
        for key, value in key_map.items():
            if not isinstance(key, str):
                raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
            if not isinstance(value, str):
                raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
            self.param_map[key] = value
        for key, value in kwargs.items():
            if not isinstance(value, str):
                raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
            self.param_map[key] = value

    def __call__(self, output_dict, target_dict, force_check=False):
        """
        :param output_dict: dict of model outputs, keyed by field name
        :param target_dict: dict of ground-truth targets, keyed by field name
        :return: dict of metric results
        """
        if not callable(self.evaluate):
            raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

        if not self._checked:
            # 1. check consistency between the evaluate signature and param_map
            func_spect = inspect.getfullargspec(self.evaluate)
            func_args = func_spect.args
            for func_param, input_param in self.param_map.items():
                if func_param not in func_args:
                    raise NameError(f"{func_param} not in {get_func_signature(self.evaluate)}.")
            # 2. only part of the param_map is passed; the rest map to themselves
            for arg in func_args:
                if arg not in self.param_map:
                    self.param_map[arg] = arg  # This param does not need mapping.
            self._evaluate_args = func_args

        # rename the inputs to match the evaluate signature.
        mapped_output_dict = {}
        mapped_target_dict = {}
        for func_arg in self._evaluate_args:
            input_arg = self.param_map[func_arg]
            if input_arg in output_dict:
                mapped_output_dict[func_arg] = output_dict[input_arg]
            if input_arg in target_dict:
                mapped_target_dict[func_arg] = target_dict[input_arg]

        # check duplicated, unused, missing
        if force_check or not self._checked:
            check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_target_dict])
            self._reverse_param_map = {value: key for key, value in self.param_map.items()}
            for key, value in check_res.items():
                new_value = value.copy()
                for idx, func_param in enumerate(value):
                    if func_param in self._reverse_param_map:
                        new_value[idx] = self._reverse_param_map[func_param]
                check_res[key] = new_value
            if check_res['missing'] or check_res['duplicated']:
                raise CheckError(check_res=check_res)
        refined_args = _build_args(self.evaluate, **mapped_output_dict, **mapped_target_dict)

        metrics = self.evaluate(**refined_args)

        if not isinstance(metrics, dict):
            raise TypeError(f"The return value of {get_func_signature(self.evaluate)} must be `dict`, "
                            f"got {type(metrics)}.")
        self._checked = True

        return metrics




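To make the key-mapping in MetricBase.__call__ concrete, here is a minimal, hypothetical subclass; the AccuracyMetric name, its evaluate signature and the 'output'/'label' keys are illustrative assumptions, not part of this commit:

import torch

class AccuracyMetric(MetricBase):
    # Hypothetical example subclass; assumes MetricBase from fastNLP/core/metrics.py above.
    def __init__(self, pred='pred', target='target'):
        super().__init__()
        # map evaluate()'s arguments to the keys used in output_dict / target_dict
        self._init_param_map({'pred': pred, 'target': target})
        self.correct = 0
        self.total = 0

    def evaluate(self, pred, target):
        self.correct += (pred.argmax(dim=-1) == target).sum().item()
        self.total += target.numel()
        return {'acc': self.correct / max(self.total, 1)}

# Intended call pattern (output_dict from model.forward/predict, target_dict from the DataSet):
# metric = AccuracyMetric(pred='output', target='label')
# metric({'output': logits}, {'label': labels})   # -> {'acc': ...}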

class CheckError(Exception):
    def __init__(self, check_res):
        err = ''
        if check_res['missing']:
            err += f"Missing: {check_res['missing']}\n"
        if check_res['duplicated']:
            err += f"Duplicated: {check_res['duplicated']}\n"
        super().__init__(err)
        self.check_res = check_res
        self.err = err

    def __str__(self):
        return self.err


class Metric(MetricBase):
    def __init__(self, func, key_map, **kwargs):
        super().__init__()
        pass

def _prepare_metrics(metrics):
    """
    Prepare a list of metrics based on the input.

    :param metrics: a MetricBase instance, a list of MetricBase instances/classes, or None
    :return: List[fastNLP.MetricBase]
    """
    _metrics = []
    if metrics:
        if isinstance(metrics, list):
            for metric in metrics:
                if isinstance(metric, type):
                    metric = metric()
                if isinstance(metric, MetricBase):
                    _metrics.append(metric)
                else:
                    raise TypeError(f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.")
        elif isinstance(metrics, MetricBase):
            _metrics = [metrics]
        else:
            raise TypeError("The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, got {}."
                            .format(type(metrics)))
    return _metrics


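A short sketch of how _prepare_metrics normalises its input, using the hypothetical AccuracyMetric from above:

# Both a single metric and a list of instances/classes are accepted;
# bare classes are instantiated with their default arguments.
metrics = _prepare_metrics(AccuracyMetric())                    # -> [AccuracyMetric instance]
metrics = _prepare_metrics([AccuracyMetric, AccuracyMetric()])  # -> list of two instances
metrics = _prepare_metrics(None)                                # -> []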

class Evaluator(object):
    def __init__(self):
@@ -17,7 +17,6 @@ class Evaluator(object):
        """
        raise NotImplementedError


class ClassifyEvaluator(Evaluator):
    def __init__(self):
        super(ClassifyEvaluator, self).__init__()


fastNLP/core/tester.py  (+41, -15)

@@ -2,32 +2,49 @@ import itertools
from collections import defaultdict

import torch
from torch import nn

from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.dataset import DataSet
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.metrics import _prepare_metrics


class Tester(object):
    """A collection of model inference and performance evaluation, used over the validation/dev set and the test set."""

    def __init__(self, data, model, metrics, batch_size=16, use_cuda=False, verbose=0):
        super(Tester, self).__init__()
        self.use_cuda = use_cuda

        if not isinstance(data, DataSet):
            raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.")
        if not isinstance(model, nn.Module):
            raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.")

        self.metrics = _prepare_metrics(metrics)

        # check predict
        if hasattr(model, 'predict'):
            self._predict_func = model.predict
            if not callable(self._predict_func):
                _model_name = model.__class__.__name__
                raise TypeError(f"`{_model_name}.predict` must be callable to be used "
                                f"for evaluation, not `{type(self._predict_func)}`.")
        else:
            self._predict_func = model

        self.data = data
        self.batch_size = batch_size
        self.verbose = verbose
        if torch.cuda.is_available() and self.use_cuda:
            self._model = model.cuda()
        else:
            self._model = model
        if hasattr(self._model, 'predict'):
            if not callable(self._model.predict):
                raise TypeError(f"{get_func_signature(model.predict)} must be callable to be used "
                                f"for evaluation.")
            self._predict_func = self._model.predict
        else:
            self._predict_func = self._model
        self.use_cuda = use_cuda
        self.batch_size = batch_size
        self.verbose = verbose

        self._model_device = model.parameters().__next__().device




    def test(self):
@@ -39,6 +56,7 @@ class Tester(object):

        with torch.no_grad():
            for batch_x, batch_y in data_iterator:
                _move_dict_value_to_device(self._model_device, batch_x, batch_y)
                prediction = self.data_forward(network, batch_x)
                assert isinstance(prediction, dict)
                for k, v in prediction.items():
@@ -49,10 +67,13 @@ class Tester(object):
            output[k] = itertools.chain(*v)
        for k, v in truths.items():
            truths[k] = itertools.chain(*v)
        # args = _build_args(self._evaluator, **output, **truths)
        eval_results = self._evaluator(**args)
        eval_results = {}
        for metric in self.metrics:
            eval_result = metric(output, truths)
            metric_name = metric.__class__.__name__
            eval_results[metric_name] = eval_result
        if self.verbose >= 0:
            print("[tester] {}".format(self.print_eval_results(eval_results)))
            print("[tester] \n{}".format(self.format_eval_results(eval_results)))
        self.mode(network, is_test=False)
        return eval_results


@@ -74,10 +95,15 @@ class Tester(object):
        y = self._predict_func(**x)
        return y


    def print_eval_results(self, results):
    def format_eval_results(self, results):
        """Override this method to support more print formats.

        :param results: dict, (str: float) is (metric name: value)

        """
        return ", ".join([str(key) + "=" + str(value) for key, value in results.items()])
        _str = ''
        for metric_name, metric_result in results.items():
            _str += metric_name + '\n\t'
            _str += ", ".join([str(key) + "=" + str(value) for key, value in metric_result.items()])
            _str += '\n'
        return _str

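A rough sketch of driving the reworked Tester; the model, the dev DataSet and the AccuracyMetric example are assumptions for illustration, not part of this commit:

def evaluate_on_dev(model, dev_data):
    # model: torch.nn.Module (optionally with a callable .predict), dev_data: fastNLP.DataSet
    tester = Tester(data=dev_data,
                    model=model,
                    metrics=[AccuracyMetric(pred='output', target='label')],
                    batch_size=32,
                    use_cuda=False,
                    verbose=0)
    # returns a dict keyed by metric class name, e.g. {'AccuracyMetric': {'acc': ...}}
    return tester.test()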
fastNLP/core/trainer.py  (+30, -41)

@@ -17,10 +17,15 @@ from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.tester import Tester
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _syn_model_data
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature
from fastNLP.core.dataset import DataSet

from fastNLP.core.losses import LossBase
from fastNLP.core.metrics import MetricBase
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics



class Trainer(object):
    """Main Training Loop
@@ -32,6 +37,25 @@ class Trainer(object):
                 **kwargs):
        super(Trainer, self).__init__()


        if not isinstance(train_data, DataSet):
            raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.")
        if not isinstance(model, nn.Module):
            raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")

        # check metrics and dev_data
        if (not metrics) and dev_data is not None:
            raise ValueError("No metric for dev_data evaluation.")
        if metrics and (dev_data is None):
            raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")

        # prepare evaluate
        metrics = _prepare_metrics(metrics)
        # prepare loss
        losser = _prepare_losser(losser)

        if need_check_code:
            _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data)

        self.train_data = train_data
        self.dev_data = dev_data  # If None, No validation.
        self.model = model
@@ -45,10 +69,7 @@ class Trainer(object):
        self.validate_every = int(validate_every)
        self._best_accuracy = 0



        # TODO check the types of loss and metrics


        self._model_device = model.parameters().__next__().device


        # TODO self._best_accuracy cannot reflect the situation where there are multiple metrics


@@ -72,38 +93,6 @@ class Trainer(object):


        # print(self.__dict__)


    def _check_params(self, train_data, model, losser, metrics=[], n_epochs=3, batch_size=32, print_every=-1,
                      validate_every=-1, dev_data=None, use_cuda=False, save_path="./save",
                      optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
                      **kwargs):
        if not isinstance(train_data, DataSet):
            raise TypeError("The type of train_data must be fastNLP.DataSet, got {}.".\
                format(type(train_data)))
        if not isinstance(model, nn.Module):
            raise TypeError("The type of model must be torch.nn.Module, got {}.".\
                format(type(model)))
        if losser is not None:
            # TODO change
            if not isinstance(losser, None):
                raise TypeError("The type of losser must be xxx, got {}.".\
                    format(type(losser)))

        # check metrics and dev_data
        if (not metrics) and dev_data is not None:
            raise ValueError("No metric for dev_data evaluation.")
        if metrics and (dev_data is None):
            raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")

        # check loss
        if isinstance(losser, type):
            self.losser = losser()
        if not isinstance(self.losser, None):
            raise TypeError(f'The type of losser must be `{}`, got {type(self.losser)}.')

        if need_check_code:
            _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data)


    def train(self):
        """Start Training.


@@ -153,8 +142,9 @@ class Trainer(object):
            - epoch: int,
        """
        for batch_x, batch_y in data_iterator:
            # TODO this may run into problems: if the user changes the device of the prediction inside the model, it will break
            _move_dict_value_to_device(self._model_device, batch_x, batch_y)
            prediction = self.data_forward(model, batch_x)

            loss = self.get_loss(prediction, batch_y)
            self.grad_backward(loss)
            self.update()
@@ -205,7 +195,6 @@ class Trainer(object):
        x = _build_args(network.forward, **x)
        y = network(**x)
        if not isinstance(y, dict):
            raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.")
        return y


@@ -299,7 +288,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
        # check loss output
        if batch_count == 0:
            if not isinstance(loss, torch.Tensor):
                raise ValueError("The return value of {} should be torch.Tensor, but got {}.".
                raise ValueError("The return value of {} should be `torch.Tensor`, but got `{}`.".
                                 format(type(losser), type(loss)))
            if len(loss.size()) != 0:
                raise ValueError("The size of return value of {} is {}, should be torch.size([])".format(
@@ -314,7 +303,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
    outputs, truths = defaultdict(list), defaultdict(list)
    dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
    # TODO change this to use the Tester
    tester = Tester(data=dataset, model=model, metrics=metrics, batch_size=batch_size)


    with torch.no_grad():
        for batch_count, (batch_x, batch_y) in enumerate(dev_batch):

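For orientation, a hedged construction sketch of the new Trainer flow; the keyword names follow the removed _check_params signature above and may not match the final __init__ exactly, and train_data, dev_data and model are assumed to exist:

def build_trainer(train_data, model, dev_data):
    # losser=None falls back to LossInForward via _prepare_losser;
    # passing metrics requires dev_data (and vice versa), see the checks above.
    return Trainer(train_data=train_data,
                   model=model,
                   losser=None,
                   metrics=[AccuracyMetric(pred='output', target='label')],
                   dev_data=dev_data,
                   need_check_code=True)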

fastNLP/core/utils.py  (+23, -30)

@@ -3,11 +3,9 @@ import inspect
import os
from collections import Counter
from collections import namedtuple
from collections import defaultdict
import torch


CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False)


def save_pickle(obj, pickle_path, file_name):
    """Save an object into a pickle file.


@@ -89,11 +87,15 @@ def _check_arg_dict_list(func, args):
    input_args = set(input_arg_count.keys())
    missing = list(require_args - input_args)
    unused = list(input_args - all_args)
    return CheckRes(missing=missing,
                    unused=unused,
                    duplicated=duplicated,
                    required=list(require_args),
                    all_needed=list(all_args))

    check_res = {}
    check_res['missing'] = missing
    check_res['unused'] = unused
    check_res['duplicated'] = duplicated
    check_res['required'] = list(require_args)
    check_res['all_needed'] = list(all_args)

    return check_res

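A small sketch of the new dict-based return value; the example function and inputs are made up for illustration:

def example_evaluate(pred, target, seq_len):
    pass

# 'target' appears in both dicts, 'seq_len' is never provided:
check_res = _check_arg_dict_list(example_evaluate, [{'pred': 1, 'target': 2}, {'target': 3}])
# expected shape of the result (exact ordering of the lists may differ):
# {'missing': ['seq_len'], 'unused': [], 'duplicated': ['target'],
#  'required': [...], 'all_needed': [...]}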

def get_func_signature(func):
    """
@@ -150,31 +152,22 @@ def _syn_model_data(model, *args):
        else:
            raise TypeError("Only support `dict` type right now.")


def _prepare_metrics(metrics):
def _move_dict_value_to_device(device, *args):
    """

    Prepare list of Metric based on input
    :param metrics:
    move data to the model's device; each element in *args should be a dict. This is an in-place change.
    :param device: torch.device
    :param args:
    :return:
    """
    _metrics = []
    if metrics:
        if isinstance(metrics, list):
            for metric in metrics:
                if isinstance(metric, type):
                    metric = metric()
                if isinstance(metric, None):
                    _metrics.append(metric)
                else:
                    raise TypeError("The type of metric in metrics must be xxxx, not {}.".format(
                        type(), type(metric)
                    ))
        elif isinstance(metrics, None):
            _metrics = [metrics]
        else:
            raise TypeError("The type of metrics should be `list[xxx]` or `xxx`, got {}.".format(
                type(metrics)
            ))
    if not isinstance(device, torch.device):
        raise TypeError(f"device must be `torch.device`, got `{type(device)}`")

    return _metrics
    for arg in args:
        if isinstance(arg, dict):
            for key, value in arg.items():
                if isinstance(value, torch.Tensor):
                    arg[key] = value.to(device)
        else:
            raise TypeError("Only support `dict` type right now.")


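A minimal usage sketch of _move_dict_value_to_device; the tensor shapes are arbitrary:

import torch

batch_x = {'words': torch.zeros(4, 10, dtype=torch.long)}
batch_y = {'label': torch.zeros(4, dtype=torch.long)}
device = torch.device('cpu')   # in Trainer/Tester this is model.parameters().__next__().device
_move_dict_value_to_device(device, batch_x, batch_y)   # tensors are replaced in place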
