Add TransformersAccuracy

tags/v1.0.0alpha
yh_cc committed 3 years ago · parent commit 36174f727d

18 changed files with 94 additions and 64 deletions
 1. fastNLP/core/__init__.py (+1, -0)
 2. fastNLP/core/callbacks/callback_manager.py (+1, -1)
 3. fastNLP/core/controllers/loops/evaluate_batch_loop.py (+8, -6)
 4. fastNLP/core/controllers/loops/train_batch_loop.py (+14, -13)
 5. fastNLP/core/dataloaders/jittor_dataloader/fdl.py (+1, -1)
 6. fastNLP/core/dataloaders/paddle_dataloader/fdl.py (+1, -1)
 7. fastNLP/core/dataloaders/prepare_dataloader.py (+1, -1)
 8. fastNLP/core/dataloaders/torch_dataloader/fdl.py (+3, -3)
 9. fastNLP/core/dataloaders/utils.py (+2, -1)
10. fastNLP/core/drivers/torch_driver/single_device.py (+24, -18)
11. fastNLP/core/metrics/__init__.py (+2, -1)
12. fastNLP/core/metrics/accuracy.py (+24, -6)
13. fastNLP/io/data_bundle.py (+3, -3)
14. fastNLP/transformers/torch/configuration_utils.py (+2, -2)
15. fastNLP/transformers/torch/generation_stopping_criteria.py (+1, -1)
16. fastNLP/transformers/torch/generation_utils.py (+2, -2)
17. fastNLP/transformers/torch/modeling_utils.py (+3, -3)
18. fastNLP/transformers/torch/models/bart/configuration_bart.py (+1, -1)

fastNLP/core/__init__.py (+1, -0)

@@ -69,6 +69,7 @@ __all__ = [
     # metrics
     "Metric",
     "Accuracy",
+    "TransformersAccuracy",
     'SpanFPreRecMetric',
     'ClassifyFPreRecMetric',



fastNLP/core/callbacks/callback_manager.py (+1, -1)

@@ -25,7 +25,7 @@ def _transfer(func):
         for callback_fn in manager.callback_fns[func.__name__]:
             try:
                 callback_fn(*arg, **kwargs)
-            except EarlyStopException as e:
+            except (EarlyStopException, KeyboardInterrupt) as e:
                 raise e
             except BaseException as e:
                 logger.error(f"The following callback_fn raise exception:{_get_fun_msg(callback_fn)}.")
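
Note: with this change a KeyboardInterrupt raised inside a callback is re-raised immediately, just like EarlyStopException, instead of falling through to the catch-all branch below it. A minimal sketch of the resulting dispatch policy (stand-in names, not fastNLP's actual classes):

    import logging

    logger = logging.getLogger("fastNLP")

    class EarlyStopException(BaseException):
        """Stand-in for fastNLP's EarlyStopException."""

    def run_callbacks(callback_fns, *args, **kwargs):
        for callback_fn in callback_fns:
            try:
                callback_fn(*args, **kwargs)
            except (EarlyStopException, KeyboardInterrupt):
                raise  # control-flow signals must reach the trainer untouched
            except BaseException:
                # everything else is logged with the offending callback first
                logger.error(f"The callback {callback_fn!r} raised an exception.")
                raise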


fastNLP/core/controllers/loops/evaluate_batch_loop.py (+8, -6)

@@ -27,19 +27,21 @@ class EvaluateBatchLoop(Loop):
         while True:
             try:
                 batch = next(iterator)
-                batch = match_and_substitute_params(evaluator.input_mapping, batch)
-                batch = evaluator.move_data_to_device(batch)
             except StopIteration:
                 break
+            try:
+                batch = match_and_substitute_params(evaluator.input_mapping, batch)
+                batch = evaluator.move_data_to_device(batch)
+
+                self.batch_step_fn(evaluator, batch)
+                batch_idx += 1
+                evaluator.update_progress_bar(batch_idx, evaluator.cur_dataloader_name)
+
             except BaseException as e:
                 if callable(getattr(dataloader, 'get_batch_indices', None)):
                     indices = dataloader.get_batch_indices()
                     logger.error(f"Exception happens when evaluating on samples: {indices}")
                 raise e
-
-            self.batch_step_fn(evaluator, batch)
-            batch_idx += 1
-            evaluator.update_progress_bar(batch_idx, evaluator.cur_dataloader_name)
         # Get the metric results. The returned dict looks like {'metric_name1': metric_results, 'metric_name2': metric_results, ...}
         results = evaluator.get_metric()
         return results
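
Note: the reworked loop separates fetching from processing. StopIteration raised by next(iterator) is now treated purely as normal exhaustion of the dataloader, while any exception raised after a batch was successfully fetched is reported together with the indices of the samples in that batch. TrainBatchLoop below gets the same split. A schematic of the pattern (a sketch, not fastNLP's exact code):

    def consume(iterator, process, report_indices):
        while True:
            try:
                batch = next(iterator)   # may raise StopIteration: normal end of data
            except StopIteration:
                break
            try:
                process(batch)           # any failure here is a genuine error
            except BaseException:
                report_indices(batch)    # e.g. log which samples were being processed
                raise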


fastNLP/core/controllers/loops/train_batch_loop.py (+14, -13)

@@ -19,30 +19,31 @@ class TrainBatchLoop(Loop):
         get_batch_indices = dataloader.get_batch_indices if callable(getattr(dataloader, 'get_batch_indices', None))\
             else lambda *args, **kwargs: None
         dataloader = iter(dataloader)
         indices = None
         while trainer.batch_idx_in_epoch<=trainer.num_batches_per_epoch:
             try:
                 trainer.on_fetch_data_begin()
                 batch = next(dataloader)
                 indices = get_batch_indices()
+            except StopIteration:
+                break
+
+            try:
                 trainer.on_fetch_data_end()
                 batch = match_and_substitute_params(trainer.input_mapping, batch)
                 batch = trainer.move_data_to_device(batch)
-            except StopIteration:
-                break
+
+                trainer.on_train_batch_begin(batch, indices)
+                with trainer.get_no_sync_context():  # sync may need to be disabled when running on multiple cards
+                    self.batch_step_fn(trainer, batch)
+                trainer.global_forward_batches += 1
+                trainer.batch_idx_in_epoch += 1
+
+                trainer.check_batch_step_fn()
+                trainer.on_train_batch_end()
             except BaseException as e:
-                if indices and not isinstance(e, EarlyStopException):
+                if indices is not None and not isinstance(e, (EarlyStopException, KeyboardInterrupt)):
                     logger.error(f"Exception happens when running on samples: {indices}")
                 raise e
-
-            trainer.on_train_batch_begin(batch, indices)
-            with trainer.get_no_sync_context():  # sync may need to be disabled when running on multiple cards
-                self.batch_step_fn(trainer, batch)
-            trainer.global_forward_batches += 1
-            trainer.batch_idx_in_epoch += 1
-
-            trainer.check_batch_step_fn()
-            trainer.on_train_batch_end()
-
             trainer.step_evaluate()
         trainer.batch_idx_in_epoch = 0



fastNLP/core/dataloaders/jittor_dataloader/fdl.py (+1, -1)

@@ -47,7 +47,7 @@ class JittorDataLoader:
     A DataLoader provided for the jittor framework; it offers the auto_collate functionality and supports datasets that implement __getitem__ and __len__.
     """

-    def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False,
+    def __init__(self, dataset, batch_size: int = 16, shuffle: bool = True,
                  drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
                  stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
                  collate_fn: Union[None, str, Callable] = "auto") -> None:


fastNLP/core/dataloaders/paddle_dataloader/fdl.py (+1, -1)

@@ -47,7 +47,7 @@ class PaddleDataLoader(DataLoader):

     def __init__(self, dataset, feed_list=None, places=None,
                  return_list: bool = True, batch_sampler=None,
-                 batch_size: int = 1, shuffle: bool = False,
+                 batch_size: int = 1, shuffle: bool = True,
                  drop_last: bool = False, collate_fn: Union[str, Callable, None] = 'auto',
                  num_workers: int = 0, use_buffer_reader: bool = True,
                  use_shared_memory: bool = True, timeout: int = 0,


fastNLP/core/dataloaders/prepare_dataloader.py (+1, -1)

@@ -14,7 +14,7 @@ from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS
 from ..log import logger


-def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False,
+def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = True, drop_last: bool = False,
                        collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0,
                        seed: int = 0, backend: str = 'auto'):
     """


fastNLP/core/dataloaders/torch_dataloader/fdl.py (+3, -3)

@@ -179,7 +179,7 @@ class TorchDataLoader(DataLoader):

 def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]],
                              batch_size: int = 1,
-                             shuffle: bool = False,
+                             shuffle: bool = True,
                              sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                              batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
                              num_workers: int = 0, collate_fn: Union[str, Callable, None] = 'auto',
@@ -236,8 +236,8 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping
                                               persistent_workers=persistent_workers,
                                               )
         else:
-            dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
-                                              shuffle=shuffle, sampler=non_train_sampler,
+            dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else batch_size,
+                                              shuffle=shuffle, sampler=non_train_sampler if non_train_sampler else sampler,
                                               batch_sampler=batch_sampler,
                                               num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                               drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
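
Note: besides the new shuffle default, non-train dataloaders built from a multi-dataset input now fall back to the train-level batch_size and sampler when non_train_batch_size / non_train_sampler are not supplied, instead of passing None through. A sketch (again assuming the top-level re-export):

    from fastNLP import DataSet, prepare_torch_dataloader

    dls = prepare_torch_dataloader(
        {'train': DataSet({'x': [[0, 1], [2, 3], [4, 5], [6, 7]]}),
         'dev': DataSet({'x': [[8, 9], [10, 11]]})},
        batch_size=2, shuffle=True)
    # the 'dev' loader now inherits batch_size=2 rather than batch_size=None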


fastNLP/core/dataloaders/utils.py (+2, -1)

@@ -1,9 +1,10 @@
+from typing import Callable
 __all__ = [
     "indice_collate_wrapper"
 ]


-def indice_collate_wrapper(func):
+def indice_collate_wrapper(func:Callable):
     """
     Wraps collate_fn in an extra layer that separates the tuple data fetched from the dataset and packs the idx values into indices.


fastNLP/core/drivers/torch_driver/single_device.py (+24, -18)

@@ -1,11 +1,13 @@
 import os
 from typing import Dict, Union, Callable, Tuple, Optional
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH

 if _NEED_IMPORT_TORCH:
     import torch
     from torch.nn import DataParallel
     from torch.nn.parallel import DistributedDataParallel
+    from torch.utils.data import RandomSampler as TorchRandomSampler
+    from torch.utils.data import SequentialSampler as TorchSequentialSampler

 __all__ = [
     'TorchSingleDriver'
@@ -15,7 +17,8 @@ from .torch_driver import TorchDriver
 from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
 from fastNLP.core.utils import auto_param_call
 from fastNLP.core.utils.utils import _get_fun_msg
-from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, ReproduceBatchSampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \
+    ReproduceBatchSampler
 from fastNLP.core.samplers import RandomSampler
 from fastNLP.core.log import logger

@@ -24,6 +27,7 @@ class TorchSingleDriver(TorchDriver):
r"""
用于 cpu 和 单卡 gpu 运算;
"""

def __init__(self, model, device: "torch.device", fp16: bool = False, **kwargs):
if isinstance(model, DistributedDataParallel):
raise ValueError("`DistributedDataParallel` is not supported in `TorchSingleDriver`")
@@ -88,7 +92,8 @@ class TorchSingleDriver(TorchDriver):
         else:
             raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")

-    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None,
+    def set_dist_repro_dataloader(self, dataloader,
+                                  dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None,
                                   reproducible: bool = False):

         # If dist is a ReproducibleBatchSampler or ReproducibleIterator, the call comes from driver.load during checkpoint-resumed training;
@@ -108,17 +113,24 @@ class TorchSingleDriver(TorchDriver):

         if reproducible:
             if isinstance(args.sampler, TorchRandomSampler):
-                # If it was random to begin with, simply replace it.
-                sampler = RandomSampler(args.sampler.data_source)
-                logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
-                return replace_sampler(dataloader, sampler)
+                if getattr(args.sampler, '_num_samples', None) is None \
+                        and getattr(args.sampler, 'replacements', False) is False \
+                        and getattr(args.sampler, 'generator', None) is None:
+                    # If it was random to begin with and has no customization, simply replace it.
+                    sampler = RandomSampler(args.sampler.data_source, shuffle=True)
+                    logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
+                    return replace_sampler(dataloader, sampler)
+            elif isinstance(args.sampler, TorchSequentialSampler):
+                # It has to be replaced by one that does not shuffle.
+                sampler = RandomSampler(args.sampler.data_source, shuffle=False)
+                logger.debug("Replace torch SequentialSampler into fastNLP RandomSampler.")
+                return replace_sampler(dataloader, sampler)
-            else:
-                batch_sampler = ReproduceBatchSampler(
-                    batch_sampler=args.batch_sampler,
-                    batch_size=args.batch_size,
-                    drop_last=args.drop_last
-                )
-                return replace_batch_sampler(dataloader, batch_sampler)
+            batch_sampler = ReproduceBatchSampler(
+                batch_sampler=args.batch_sampler,
+                batch_size=args.batch_size,
+                drop_last=args.drop_last
+            )
+            return replace_batch_sampler(dataloader, batch_sampler)
         else:
             return dataloader
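
Note: the reproducible path above is now more conservative. A torch RandomSampler is swapped for fastNLP's reproducible RandomSampler only when it carries no customization (no fixed _num_samples, no replacement, no custom generator); a SequentialSampler becomes a non-shuffling fastNLP RandomSampler so that iteration state can still be saved; everything else keeps its own batch_sampler wrapped in a ReproduceBatchSampler. A simplified, illustrative decision helper (not the driver's actual API):

    from torch.utils.data import RandomSampler, SequentialSampler

    def replacement_strategy(sampler):
        """Mirror of the branch structure above, returning a description."""
        if isinstance(sampler, RandomSampler) and sampler.generator is None \
                and not sampler.replacement and sampler._num_samples is None:
            return "swap in fastNLP RandomSampler(shuffle=True)"
        if isinstance(sampler, SequentialSampler):
            return "swap in fastNLP RandomSampler(shuffle=False)"
        return "wrap the existing batch_sampler in ReproduceBatchSampler"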

@@ -138,9 +150,3 @@ class TorchSingleDriver(TorchDriver):

     def is_distributed(self):
         return False
-
-
-
-
-
-
fastNLP/core/metrics/__init__.py (+2, -1)

@@ -1,11 +1,12 @@
 __all__ = [
     "Metric",
     "Accuracy",
+    "TransformersAccuracy",
     'SpanFPreRecMetric',
     'ClassifyFPreRecMetric',
 ]

 from .metric import Metric
-from .accuracy import Accuracy
+from .accuracy import Accuracy, TransformersAccuracy
 from .span_f1_pre_rec_metric import SpanFPreRecMetric
 from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric

fastNLP/core/metrics/accuracy.py (+24, -6)

@@ -1,5 +1,6 @@
 __all__ = [
-    'Accuracy'
+    'Accuracy',
+    "TransformersAccuracy"
 ]

 from typing import Union
@@ -17,9 +18,9 @@ class Accuracy(Metric):
"""
计算 准确率 的 metric 。

:param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update()
:param backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update()
函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到 metric,
:param aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到 metric,
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。
"""
super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
@@ -39,11 +40,11 @@ class Accuracy(Metric):
r"""
update 函数将针对一个批次的预测结果做评价指标的累计

:param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]),
:param pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]),
torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes])
:param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]),
:param target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]),
torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len])
:param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, None, torch.Size([B]), 或者torch.Size([B]).
:param seq_len: 序列长度标记, 标记的形状可以是None, None, torch.Size([B]), 或者torch.Size([B]).
如果mask也被传进来的话seq_len会被忽略.
"""
# 为了兼容不同框架,我们将输入变量全部转为numpy类型来进行计算。
@@ -79,3 +80,20 @@ class Accuracy(Metric):
         else:
             self.total += np.prod(list(pred.shape)).item()
             self.correct += (target == pred).sum().item()
+
+
+class TransformersAccuracy(Accuracy):
+    """
+    An Accuracy metric adapted to the relevant models from transformers.
+
+    """
+    def update(self, logits, labels, attention_mask=None):
+        r"""
+        The update function accumulates the evaluation results over one batch of predictions.
+
+        :param logits: shape ``[B, n_classes]`` or ``[B, max_len, n_classes]``.
+        :param labels: shape ``[B, ]`` or ``[B, max_len]``.
+        :param attention_mask: sequence length mask.
+        """
+        seq_len = attention_mask.sum(dim=-1)
+        super().update(pred=logits, target=labels, seq_len=seq_len)
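
Note: TransformersAccuracy simply remaps transformers-style model outputs onto Accuracy.update: logits serve as pred, labels as target, and seq_len is derived from the attention mask so that padded positions are excluded. Also note that update calls attention_mask.sum(dim=-1) unconditionally, so despite the None default an attention_mask must actually be passed. A minimal usage sketch with made-up shapes (assuming the torch backend):

    import torch
    from fastNLP import TransformersAccuracy

    metric = TransformersAccuracy(backend='torch')
    logits = torch.randn(2, 5, 4)           # [B, max_len, n_classes]
    labels = torch.randint(0, 4, (2, 5))    # [B, max_len]
    attention_mask = torch.ones(2, 5)       # as produced by the tokenizer
    metric.update(logits=logits, labels=labels, attention_mask=attention_mask)
    print(metric.get_metric())              # e.g. {'acc': 0.25}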

fastNLP/io/data_bundle.py (+3, -3)

@@ -249,7 +249,7 @@ class DataBundle:
         return self

     def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True,
-                         ignore_miss_dataset=True, progress_desc: str = '', show_progress_bar: bool = True):
+                         ignore_miss_dataset=True, show_progress_bar: bool = True, progress_desc: str = ''):
         r"""
         Applies the :meth:`~fastNLP.DataSet.apply_field_more` method to every dataset in the :class:`~fastNLP.io.DataBundle`
@@ -263,8 +263,8 @@ class DataBundle:
         :param num_proc: the number of processes. Note that, as a property of Python, the memory footprint grows with the number of processes.
         :param bool ignore_miss_dataset: when a field name is missing from some dataset: if True, that DataSet is simply skipped;
             if False, an error is raised
-        :param show_progress_bar: whether to show the tqdm progress bar
-        :param progress_desc: when show_progress_barm is True, the name tqdm displays for the current task
+        :param show_progress_bar: whether to show the progress bar
+        :param progress_desc: when ``show_progress_bar`` is ``True``, the ``progress`` name that is displayed.

         :return Dict[str:Dict[str:Field]]: a dict of dicts; the first-level keys are dataset names, the second-level keys are field names



fastNLP/transformers/torch/configuration_utils.py (+2, -2)

@@ -314,7 +314,7 @@ class PretrainedConfig:

     # TPU arguments
     if kwargs.pop("xla_device", None) is not None:
-        logger.warning(
+        logger.rank_zero_warning(
             "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can "
             "safely remove it from your `config.json` file."
         )
@@ -474,7 +474,7 @@ class PretrainedConfig:
         """
         config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
         if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
-            logger.warn(
+            logger.rank_zero_warning(
                 f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                 f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
             )
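
Note: this and the remaining files switch the vendored transformers code from logger.warn / logger.warning to fastNLP's rank_zero_warning, so that in distributed runs each message is emitted only once, on rank 0, instead of once per process. A sketch of the idea (simplified; not fastNLP's actual implementation):

    import logging
    import os

    logger = logging.getLogger("fastNLP")

    def rank_zero_warning(msg, *args, **kwargs):
        # emit only on the process whose (environment-provided) rank is 0
        if int(os.environ.get("RANK", "0")) == 0:
            logger.warning(msg)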


fastNLP/transformers/torch/generation_stopping_criteria.py (+1, -1)

@@ -122,7 +122,7 @@ def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_leng
     stopping_max_length = stopping_criteria.max_length
     new_stopping_criteria = deepcopy(stopping_criteria)
     if stopping_max_length is not None and stopping_max_length != max_length:
-        logger.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
+        logger.rank_zero_warning("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
     elif stopping_max_length is None:
         new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
     return new_stopping_criteria

fastNLP/transformers/torch/generation_utils.py (+2, -2)

@@ -429,7 +429,7 @@ class GenerationMixin:

     def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int:
         if pad_token_id is None and eos_token_id is not None:
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+            logger.rank_zero_warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             pad_token_id = eos_token_id
         return pad_token_id

@@ -912,7 +912,7 @@ class GenerationMixin:

         # special case if pad_token_id is not defined
         if pad_token_id is None and eos_token_id is not None:
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+            logger.rank_zero_warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             pad_token_id = eos_token_id

         # Storing encoder_input_ids for logits_processor that could use them


fastNLP/transformers/torch/modeling_utils.py (+3, -3)

@@ -352,7 +352,7 @@ class ModuleUtilsMixin:
         if token_inputs:
             return sum([token_input.numel() for token_input in token_inputs])
         else:
-            logger.warn(
+            logger.rank_zero_warning(
                 "Could not estimate the number of tokens of the input, floating-point operations will not be computed"
             )
             return 0
@@ -646,7 +646,7 @@ class PreTrainedModel(Module, ModuleUtilsMixin, GenerationMixin):
         # tie weights recursively
         tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights)
         if len(uninitialized_encoder_weights) > 0:
-            logger.warning(
+            logger.rank_zero_warning(
                 f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
             )
@@ -1486,7 +1486,7 @@ class PreTrainedModel(Module, ModuleUtilsMixin, GenerationMixin):
             raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

         if len(unexpected_keys) > 0:
-            logger.warning(
+            logger.rank_zero_warning(
                 f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                 f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                 f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "


fastNLP/transformers/torch/models/bart/configuration_bart.py (+1, -1)

@@ -171,7 +171,7 @@ class BartConfig(PretrainedConfig):
         # ensure backward compatibility for BART CNN models
         if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
             self.forced_bos_token_id = self.bos_token_id
-            logger.warn(
+            logger.rank_zero_warning(
                 f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions."
                 "The config can simply be saved and uploaded again to be fixed."
             )
