
Add TransformersAccuracy

tags/v1.0.0alpha
yh_cc committed 3 years ago · parent commit 36174f727d
18 changed files with 94 additions and 64 deletions
  1. fastNLP/core/__init__.py (+1 -0)
  2. fastNLP/core/callbacks/callback_manager.py (+1 -1)
  3. fastNLP/core/controllers/loops/evaluate_batch_loop.py (+8 -6)
  4. fastNLP/core/controllers/loops/train_batch_loop.py (+14 -13)
  5. fastNLP/core/dataloaders/jittor_dataloader/fdl.py (+1 -1)
  6. fastNLP/core/dataloaders/paddle_dataloader/fdl.py (+1 -1)
  7. fastNLP/core/dataloaders/prepare_dataloader.py (+1 -1)
  8. fastNLP/core/dataloaders/torch_dataloader/fdl.py (+3 -3)
  9. fastNLP/core/dataloaders/utils.py (+2 -1)
  10. fastNLP/core/drivers/torch_driver/single_device.py (+24 -18)
  11. fastNLP/core/metrics/__init__.py (+2 -1)
  12. fastNLP/core/metrics/accuracy.py (+24 -6)
  13. fastNLP/io/data_bundle.py (+3 -3)
  14. fastNLP/transformers/torch/configuration_utils.py (+2 -2)
  15. fastNLP/transformers/torch/generation_stopping_criteria.py (+1 -1)
  16. fastNLP/transformers/torch/generation_utils.py (+2 -2)
  17. fastNLP/transformers/torch/modeling_utils.py (+3 -3)
  18. fastNLP/transformers/torch/models/bart/configuration_bart.py (+1 -1)

fastNLP/core/__init__.py (+1 -0)

@@ -69,6 +69,7 @@ __all__ = [
     # metrics
     "Metric",
     "Accuracy",
+    "TransformersAccuracy",
     'SpanFPreRecMetric',
     'ClassifyFPreRecMetric',




fastNLP/core/callbacks/callback_manager.py (+1 -1)

@@ -25,7 +25,7 @@ def _transfer(func):
         for callback_fn in manager.callback_fns[func.__name__]:
             try:
                 callback_fn(*arg, **kwargs)
-            except EarlyStopException as e:
+            except (EarlyStopException, KeyboardInterrupt) as e:
                 raise e
             except BaseException as e:
                 logger.error(f"The following callback_fn raise exception:{_get_fun_msg(callback_fn)}.")


fastNLP/core/controllers/loops/evaluate_batch_loop.py (+8 -6)

@@ -27,19 +27,21 @@ class EvaluateBatchLoop(Loop):
         while True:
             try:
                 batch = next(iterator)
-                batch = match_and_substitute_params(evaluator.input_mapping, batch)
-                batch = evaluator.move_data_to_device(batch)
             except StopIteration:
                 break
+            try:
+                batch = match_and_substitute_params(evaluator.input_mapping, batch)
+                batch = evaluator.move_data_to_device(batch)
+
+                self.batch_step_fn(evaluator, batch)
+                batch_idx += 1
+                evaluator.update_progress_bar(batch_idx, evaluator.cur_dataloader_name)
+
             except BaseException as e:
                 if callable(getattr(dataloader, 'get_batch_indices', None)):
                     indices = dataloader.get_batch_indices()
                     logger.error(f"Exception happens when evaluating on samples: {indices}")
                 raise e
-
-            self.batch_step_fn(evaluator, batch)
-            batch_idx += 1
-            evaluator.update_progress_bar(batch_idx, evaluator.cur_dataloader_name)
         # Collect the metric results. The returned dict looks like {'metric_name1': metric_results, 'metric_name2': metric_results, ...}
         results = evaluator.get_metric()
         return results
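
The effect of the reshuffle is that StopIteration now only guards next(iterator); previously a stray StopIteration raised inside the input mapping or device move could end evaluation silently. A minimal sketch of the new control flow, with placeholder names rather than the Evaluator API:

def evaluate_batches(iterator, process_batch):
    batch_idx = 0
    while True:
        try:
            batch = next(iterator)      # only real exhaustion breaks the loop
        except StopIteration:
            break
        try:
            process_batch(batch)        # input mapping, device move, step fn
            batch_idx += 1
        except BaseException:
            print(f"failed while evaluating batch {batch_idx}")
            raise
    return batch_idx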


fastNLP/core/controllers/loops/train_batch_loop.py (+14 -13)

@@ -19,30 +19,31 @@ class TrainBatchLoop(Loop):
         get_batch_indices = dataloader.get_batch_indices if callable(getattr(dataloader, 'get_batch_indices', None))\
             else lambda *args, **kwargs: None
         dataloader = iter(dataloader)
+        indices = None
         while trainer.batch_idx_in_epoch<=trainer.num_batches_per_epoch:
             try:
                 trainer.on_fetch_data_begin()
                 batch = next(dataloader)
                 indices = get_batch_indices()
+            except StopIteration:
+                break
+
+            try:
                 trainer.on_fetch_data_end()
                 batch = match_and_substitute_params(trainer.input_mapping, batch)
                 batch = trainer.move_data_to_device(batch)
-            except StopIteration:
-                break
+
+                trainer.on_train_batch_begin(batch, indices)
+                with trainer.get_no_sync_context():  # sync may need to be disabled when running on multiple cards
+                    self.batch_step_fn(trainer, batch)
+                trainer.global_forward_batches += 1
+                trainer.batch_idx_in_epoch += 1
+
+                trainer.check_batch_step_fn()
+                trainer.on_train_batch_end()
             except BaseException as e:
-                if indices and not isinstance(e, EarlyStopException):
+                if indices is not None and not isinstance(e, (EarlyStopException, KeyboardInterrupt)):
                     logger.error(f"Exception happens when running on samples: {indices}")
                 raise e
-
-            trainer.on_train_batch_begin(batch, indices)
-            with trainer.get_no_sync_context():  # sync may need to be disabled when running on multiple cards
-                self.batch_step_fn(trainer, batch)
-            trainer.global_forward_batches += 1
-            trainer.batch_idx_in_epoch += 1
-
-            trainer.check_batch_step_fn()
-            trainer.on_train_batch_end()
         trainer.step_evaluate()
         trainer.batch_idx_in_epoch = 0
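
For readers wondering about the no-sync context retained in the moved block: on the torch backend this usually maps to DistributedDataParallel's no_sync(), and is a no-op on a single device. A plausible sketch of such a helper using only public torch APIs, not the committed fastNLP code:

from contextlib import nullcontext

def get_no_sync_context(model):
    # DistributedDataParallel exposes no_sync() to skip the gradient
    # all-reduce for a step; anything else gets a do-nothing context.
    if hasattr(model, "no_sync"):
        return model.no_sync()
    return nullcontext()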




fastNLP/core/dataloaders/jittor_dataloader/fdl.py (+1 -1)

@@ -47,7 +47,7 @@ class JittorDataLoader:
     A DataLoader for the jittor framework. It provides auto_collate and supports any dataset implementing __getitem__ and __len__.
     """

-    def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False,
+    def __init__(self, dataset, batch_size: int = 16, shuffle: bool = True,
                  drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
                  stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
                  collate_fn: Union[None, str, Callable] = "auto") -> None:


fastNLP/core/dataloaders/paddle_dataloader/fdl.py (+1 -1)

@@ -47,7 +47,7 @@ class PaddleDataLoader(DataLoader):

     def __init__(self, dataset, feed_list=None, places=None,
                  return_list: bool = True, batch_sampler=None,
-                 batch_size: int = 1, shuffle: bool = False,
+                 batch_size: int = 1, shuffle: bool = True,
                  drop_last: bool = False, collate_fn: Union[str, Callable, None] = 'auto',
                  num_workers: int = 0, use_buffer_reader: bool = True,
                  use_shared_memory: bool = True, timeout: int = 0,


fastNLP/core/dataloaders/prepare_dataloader.py (+1 -1)

@@ -14,7 +14,7 @@ from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS
 from ..log import logger


-def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False,
+def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = True, drop_last: bool = False,
                        collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0,
                        seed: int = 0, backend: str = 'auto'):
     """


fastNLP/core/dataloaders/torch_dataloader/fdl.py (+3 -3)

@@ -179,7 +179,7 @@ class TorchDataLoader(DataLoader):

 def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]],
                              batch_size: int = 1,
-                             shuffle: bool = False,
+                             shuffle: bool = True,
                              sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                              batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
                              num_workers: int = 0, collate_fn: Union[str, Callable, None] = 'auto',

@@ -236,8 +236,8 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping
                                               persistent_workers=persistent_workers,
                                               )
             else:
-                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
-                                                  shuffle=shuffle, sampler=non_train_sampler,
+                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else batch_size,
+                                                  shuffle=shuffle, sampler=non_train_sampler if non_train_sampler else sampler,
                                                   batch_sampler=batch_sampler,
                                                   num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                                   drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
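
The second hunk changes the fallback for non-train splits: when a mapping of datasets is passed and no eval-specific batch size or sampler was supplied, the non-train loaders now inherit the train-time batch_size and sampler instead of receiving None. Illustratively, with toy datasets and arbitrary field names:

from fastNLP import DataSet
from fastNLP.core.dataloaders.torch_dataloader.fdl import prepare_torch_dataloader

train_ds = DataSet({'x': list(range(8))})
dev_ds = DataSet({'x': list(range(4))})

# 'dev' previously got batch_size=None here; it now falls back to 32.
dls = prepare_torch_dataloader({'train': train_ds, 'dev': dev_ds}, batch_size=32)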


fastNLP/core/dataloaders/utils.py (+2 -1)

@@ -1,9 +1,10 @@
+from typing import Callable
 __all__ = [
     "indice_collate_wrapper"
 ]


-def indice_collate_wrapper(func):
+def indice_collate_wrapper(func: Callable):
     """
     Wraps a layer around a collate_fn, splitting apart the tuples fetched from the dataset and packing the idx values into indices.
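
The docstring describes splitting the (idx, sample) tuples the dataset yields and packing the idx values into indices, but the body is not shown in this hunk; the following is only a guess at its shape, reconstructed from that description:

from typing import Callable

def indice_collate_wrapper(func: Callable):
    def wrapper(tuple_batch):
        indices, samples = [], []
        for idx, sample in tuple_batch:   # dataset yields (idx, sample) pairs
            indices.append(idx)
            samples.append(sample)
        return indices, func(samples)     # indices kept for error reporting
    return wrapper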




fastNLP/core/drivers/torch_driver/single_device.py (+24 -18)

@@ -1,11 +1,13 @@
import os import os
from typing import Dict, Union, Callable, Tuple, Optional from typing import Dict, Union, Callable, Tuple, Optional
from fastNLP.envs.imports import _NEED_IMPORT_TORCH from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH: if _NEED_IMPORT_TORCH:
import torch import torch
from torch.nn import DataParallel from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import RandomSampler as TorchRandomSampler from torch.utils.data import RandomSampler as TorchRandomSampler
from torch.utils.data import SequentialSampler as TorchSequentialSampler


__all__ = [ __all__ = [
'TorchSingleDriver' 'TorchSingleDriver'
@@ -15,7 +17,8 @@ from .torch_driver import TorchDriver
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
from fastNLP.core.utils import auto_param_call from fastNLP.core.utils import auto_param_call
from fastNLP.core.utils.utils import _get_fun_msg from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, ReproduceBatchSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \
ReproduceBatchSampler
from fastNLP.core.samplers import RandomSampler from fastNLP.core.samplers import RandomSampler
from fastNLP.core.log import logger from fastNLP.core.log import logger


@@ -24,6 +27,7 @@ class TorchSingleDriver(TorchDriver):
r""" r"""
用于 cpu 和 单卡 gpu 运算; 用于 cpu 和 单卡 gpu 运算;
""" """

def __init__(self, model, device: "torch.device", fp16: bool = False, **kwargs): def __init__(self, model, device: "torch.device", fp16: bool = False, **kwargs):
if isinstance(model, DistributedDataParallel): if isinstance(model, DistributedDataParallel):
raise ValueError("`DistributedDataParallel` is not supported in `TorchSingleDriver`") raise ValueError("`DistributedDataParallel` is not supported in `TorchSingleDriver`")
@@ -88,7 +92,8 @@ class TorchSingleDriver(TorchDriver):
else: else:
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")


def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None,
def set_dist_repro_dataloader(self, dataloader,
dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None,
reproducible: bool = False): reproducible: bool = False):


# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; # 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用;
@@ -108,17 +113,24 @@ class TorchSingleDriver(TorchDriver):


if reproducible: if reproducible:
if isinstance(args.sampler, TorchRandomSampler): if isinstance(args.sampler, TorchRandomSampler):
# 如果本来就是随机的,直接替换掉吧。
sampler = RandomSampler(args.sampler.data_source)
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
if getattr(args.sampler, '_num_samples', None) is None \
and getattr(args.sampler, 'replacements', False) is False \
and getattr(args.sampler, 'generator', None) is None:
# 如果本来就是随机的,并且没有定制,直接替换掉吧。
sampler = RandomSampler(args.sampler.data_source, shuffle=True)
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
elif isinstance(args.sampler, TorchSequentialSampler):
# 需要替换为不要 shuffle 的。
sampler = RandomSampler(args.sampler.data_source, shuffle=False)
logger.debug("Replace torch SequentialSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler) return replace_sampler(dataloader, sampler)
else:
batch_sampler = ReproduceBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last
)
return replace_batch_sampler(dataloader, batch_sampler)
batch_sampler = ReproduceBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last
)
return replace_batch_sampler(dataloader, batch_sampler)
else: else:
return dataloader return dataloader


@@ -138,9 +150,3 @@ class TorchSingleDriver(TorchDriver):


def is_distributed(self): def is_distributed(self):
return False return False
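
The new branches distinguish three torch sampler situations: only a completely default torch RandomSampler is swapped for fastNLP's reproducible RandomSampler, a SequentialSampler becomes a non-shuffling replacement, and anything customized falls through to ReproduceBatchSampler. A demonstration of those attribute checks on plain torch objects (note that torch's RandomSampler stores the with-replacement flag under the name replacement):

import torch
from torch.utils.data import RandomSampler, SequentialSampler

data = list(range(10))
plain = RandomSampler(data)
seeded = RandomSampler(data, generator=torch.Generator().manual_seed(0))

def is_default_random(s):
    # mirrors the three getattr checks in the hunk above
    return (getattr(s, '_num_samples', None) is None
            and getattr(s, 'replacement', False) is False
            and getattr(s, 'generator', None) is None)

print(is_default_random(plain))    # True  -> replaced with shuffle=True
print(is_default_random(seeded))   # False -> falls through to ReproduceBatchSampler
print(isinstance(SequentialSampler(data), SequentialSampler))  # True -> replaced with shuffle=False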

fastNLP/core/metrics/__init__.py (+2 -1)

@@ -1,11 +1,12 @@
 __all__ = [
     "Metric",
     "Accuracy",
+    "TransformersAccuracy",
     'SpanFPreRecMetric',
     'ClassifyFPreRecMetric',
 ]

 from .metric import Metric
-from .accuracy import Accuracy
+from .accuracy import Accuracy, TransformersAccuracy
 from .span_f1_pre_rec_metric import SpanFPreRecMetric
 from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric

fastNLP/core/metrics/accuracy.py (+24 -6)

@@ -1,5 +1,6 @@
 __all__ = [
-    'Accuracy'
+    'Accuracy',
+    "TransformersAccuracy"
 ]

 from typing import Union

@@ -17,9 +18,9 @@ class Accuracy(Metric):
         """
         Metric that computes accuracy.

-        :param str backend: Four types of backend are currently supported: ['auto', 'torch', 'paddle', 'jittor'], where 'auto' means the
+        :param backend: Four types of backend are currently supported: ['auto', 'torch', 'paddle', 'jittor'], where 'auto' means the
             backend is decided by the arguments actually passed to Metric.update(); in most cases 'auto' is all you need.
-        :param bool aggregate_when_get_metric: Whether, when computing the metric, to first aggregate the counts of the same element across processes;
+        :param aggregate_when_get_metric: Whether, when computing the metric, to first aggregate the counts of the same element across processes;
             meaningless when the backend does not support distributed execution. If None, it is set automatically in the Evaluator depending on whether the sampler is distributed.
         """
         super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)

@@ -39,11 +40,11 @@ class Accuracy(Metric):
         r"""
         update() accumulates the evaluation metric over one batch of predictions

-        :param torch.Tensor pred: the prediction tensor; its shape can be torch.Size([B,]), torch.Size([B, n_classes]),
+        :param pred: the prediction tensor; its shape can be torch.Size([B,]), torch.Size([B, n_classes]),
             torch.Size([B, max_len]), or torch.Size([B, max_len, n_classes])
-        :param torch.Tensor target: the ground-truth tensor; its shape can be torch.Size([B,]),
+        :param target: the ground-truth tensor; its shape can be torch.Size([B,]),
             or torch.Size([B, max_len])
-        :param torch.Tensor seq_len: the sequence-length marker; its shape can be None or torch.Size([B]).
+        :param seq_len: the sequence-length marker; its shape can be None or torch.Size([B]).
             seq_len is ignored if a mask is also passed in.
         """
         # For compatibility across frameworks, all input variables are converted to numpy for the computation.

@@ -79,3 +80,20 @@
         else:
             self.total += np.prod(list(pred.shape)).item()
             self.correct += (target == pred).sum().item()
+
+
+class TransformersAccuracy(Accuracy):
+    """
+    Accuracy metric adapted to the relevant models in transformers.
+
+    """
+    def update(self, logits, labels, attention_mask=None):
+        r"""
+        update() accumulates the evaluation metric over one batch of predictions
+
+        :param logits: shape ``[B, n_classes]`` or ``[B, max_len, n_classes]``.
+        :param labels: shape ``[B, ]`` or ``[B, max_len]``
+        :param attention_mask: marks the sequence lengths.
+        """
+        seq_len = attention_mask.sum(dim=-1)
+        super().update(pred=logits, target=labels, seq_len=seq_len)
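
A quick smoke test of the new metric, assuming the v1.0 Metric API where accumulated results are read back with get_metric(); the shapes follow the docstring above and the numbers are made up:

import torch
from fastNLP.core.metrics import TransformersAccuracy

logits = torch.randn(2, 4, 3)                  # [B, max_len, n_classes]
labels = torch.randint(0, 3, (2, 4))           # [B, max_len]
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])  # padding positions are 0

metric = TransformersAccuracy()
metric.update(logits=logits, labels=labels, attention_mask=attention_mask)
print(metric.get_metric())                     # e.g. {'acc': 0.6}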

fastNLP/io/data_bundle.py (+3 -3)

@@ -249,7 +249,7 @@ class DataBundle:
         return self

     def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True,
-                         ignore_miss_dataset=True, progress_desc: str = '', show_progress_bar: bool = True):
+                         ignore_miss_dataset=True, show_progress_bar: bool = True, progress_desc: str = ''):
         r"""
         Calls :meth:`~fastNLP.DataSet.apply_field_more` on every dataset held by this :class:`~fastNLP.io.DataBundle`

@@ -263,8 +263,8 @@ class DataBundle:
         :param num_proc: number of processes. Note that, due to the semantics of Python, each extra process multiplies the memory footprint.
         :param bool ignore_miss_dataset: when the field name is absent from some dataset: if True, that DataSet is simply skipped;
             if False, an error is raised
-        :param show_progress_bar: whether to show a tqdm progress bar
-        :param progress_desc: when show_progress_bar is True, sets the name tqdm shows for the current task
+        :param show_progress_bar: whether to show a progress bar
+        :param progress_desc: when ``show_progress_bar`` is ``True``, sets the name shown on the ``progress`` bar.

         :return Dict[str:Dict[str:Field]]: returns a dict of dicts; the first-level key is the dataset name, the second-level key is the field name
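
Because this hunk swaps the positions of show_progress_bar and progress_desc, callers passing them positionally would now be broken, while keyword callers are unaffected. An illustrative keyword-only call on a toy bundle (field names are arbitrary):

from fastNLP import DataSet
from fastNLP.io import DataBundle

bundle = DataBundle(datasets={'train': DataSet({'x': [[1, 2], [3, 4, 5]]})})
bundle.apply_field_more(lambda x: {'x_len': len(x)}, field_name='x',
                        show_progress_bar=True, progress_desc='len')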




fastNLP/transformers/torch/configuration_utils.py (+2 -2)

@@ -314,7 +314,7 @@ class PretrainedConfig:

         # TPU arguments
         if kwargs.pop("xla_device", None) is not None:
-            logger.warning(
+            logger.rank_zero_warning(
                 "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can "
                 "safely remove it from your `config.json` file."
             )

@@ -474,7 +474,7 @@ class PretrainedConfig:
         """
         config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
         if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
-            logger.warn(
+            logger.rank_zero_warning(
                 f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                 f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
             )
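
This file and the four after it make the same substitution: logger.warn / logger.warning become logger.rank_zero_warning, so each warning is printed once rather than once per process in distributed runs. A minimal sketch of what such a helper does, assuming the conventional RANK environment variable (fastNLP's real helper lives on its logger object):

import os

def rank_zero_warning(msg: str):
    # only the rank-0 process emits the warning; other ranks stay silent
    if int(os.environ.get("RANK", "0")) == 0:
        print(f"WARNING: {msg}")

rank_zero_warning("`xla_device` is deprecated and will be ignored")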


fastNLP/transformers/torch/generation_stopping_criteria.py (+1 -1)

@@ -122,7 +122,7 @@ def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_leng
     stopping_max_length = stopping_criteria.max_length
     new_stopping_criteria = deepcopy(stopping_criteria)
     if stopping_max_length is not None and stopping_max_length != max_length:
-        logger.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
+        logger.rank_zero_warning("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
     elif stopping_max_length is None:
         new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
     return new_stopping_criteria

fastNLP/transformers/torch/generation_utils.py (+2 -2)

@@ -429,7 +429,7 @@ class GenerationMixin:

     def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int:
         if pad_token_id is None and eos_token_id is not None:
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+            logger.rank_zero_warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             pad_token_id = eos_token_id
         return pad_token_id

@@ -912,7 +912,7 @@ class GenerationMixin:

         # special case if pad_token_id is not defined
         if pad_token_id is None and eos_token_id is not None:
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+            logger.rank_zero_warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             pad_token_id = eos_token_id

         # Storing encoder_input_ids for logits_processor that could use them


fastNLP/transformers/torch/modeling_utils.py (+3 -3)

@@ -352,7 +352,7 @@ class ModuleUtilsMixin:
         if token_inputs:
             return sum([token_input.numel() for token_input in token_inputs])
         else:
-            logger.warn(
+            logger.rank_zero_warning(
                 "Could not estimate the number of tokens of the input, floating-point operations will not be computed"
             )
             return 0

@@ -646,7 +646,7 @@ class PreTrainedModel(Module, ModuleUtilsMixin, GenerationMixin):
         # tie weights recursively
         tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights)
         if len(uninitialized_encoder_weights) > 0:
-            logger.warning(
+            logger.rank_zero_warning(
                 f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
             )

@@ -1486,7 +1486,7 @@ class PreTrainedModel(Module, ModuleUtilsMixin, GenerationMixin):
             raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

         if len(unexpected_keys) > 0:
-            logger.warning(
+            logger.rank_zero_warning(
                 f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                 f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                 f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "


fastNLP/transformers/torch/models/bart/configuration_bart.py (+1 -1)

@@ -171,7 +171,7 @@ class BartConfig(PretrainedConfig):
         # ensure backward compatibility for BART CNN models
         if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
             self.forced_bos_token_id = self.bos_token_id
-            logger.warn(
+            logger.rank_zero_warning(
                 f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions."
                 "The config can simply be saved and uploaded again to be fixed."
             )
