@@ -69,6 +69,7 @@ __all__ = [
     # metrics
     "Metric",
     "Accuracy",
+    "TransformersAccuracy",
     'SpanFPreRecMetric',
     'ClassifyFPreRecMetric',
@@ -25,7 +25,7 @@ def _transfer(func):
         for callback_fn in manager.callback_fns[func.__name__]:
             try:
                 callback_fn(*arg, **kwargs)
-            except EarlyStopException as e:
+            except (EarlyStopException, KeyboardInterrupt) as e:
                 raise e
             except BaseException as e:
                 logger.error(f"The following callback_fn raise exception:{_get_fun_msg(callback_fn)}.")
@@ -27,19 +27,21 @@ class EvaluateBatchLoop(Loop):
         while True:
             try:
                 batch = next(iterator)
-                batch = match_and_substitute_params(evaluator.input_mapping, batch)
-                batch = evaluator.move_data_to_device(batch)
             except StopIteration:
                 break
+            try:
+                batch = match_and_substitute_params(evaluator.input_mapping, batch)
+                batch = evaluator.move_data_to_device(batch)
+                self.batch_step_fn(evaluator, batch)
+                batch_idx += 1
+                evaluator.update_progress_bar(batch_idx, evaluator.cur_dataloader_name)
             except BaseException as e:
                 if callable(getattr(dataloader, 'get_batch_indices', None)):
                     indices = dataloader.get_batch_indices()
                     logger.error(f"Exception happens when evaluating on samples: {indices}")
                 raise e
-            self.batch_step_fn(evaluator, batch)
-            batch_idx += 1
-            evaluator.update_progress_bar(batch_idx, evaluator.cur_dataloader_name)
         # Collect the metric results. The returned dict looks like {'metric_name1': metric_results, 'metric_name2': metric_results, ...}
         results = evaluator.get_metric()
         return results
@@ -19,30 +19,31 @@ class TrainBatchLoop(Loop):
         get_batch_indices = dataloader.get_batch_indices if callable(getattr(dataloader, 'get_batch_indices', None))\
             else lambda *args, **kwargs: None
         dataloader = iter(dataloader)
+        indices = None
         while trainer.batch_idx_in_epoch<=trainer.num_batches_per_epoch:
             try:
                 trainer.on_fetch_data_begin()
                 batch = next(dataloader)
                 indices = get_batch_indices()
+            except StopIteration:
+                break
+            try:
                 trainer.on_fetch_data_end()
                 batch = match_and_substitute_params(trainer.input_mapping, batch)
                 batch = trainer.move_data_to_device(batch)
-            except StopIteration:
-                break
+                trainer.on_train_batch_begin(batch, indices)
+                with trainer.get_no_sync_context():  # sync may need to be disabled when running on multiple GPUs
+                    self.batch_step_fn(trainer, batch)
+                trainer.global_forward_batches += 1
+                trainer.batch_idx_in_epoch += 1
+                trainer.check_batch_step_fn()
+                trainer.on_train_batch_end()
             except BaseException as e:
-                if indices and not isinstance(e, EarlyStopException):
+                if indices is not None and not isinstance(e, (EarlyStopException, KeyboardInterrupt)):
                     logger.error(f"Exception happens when running on samples: {indices}")
                 raise e
-            trainer.on_train_batch_begin(batch, indices)
-            with trainer.get_no_sync_context():  # sync may need to be disabled when running on multiple GPUs
-                self.batch_step_fn(trainer, batch)
-            trainer.global_forward_batches += 1
-            trainer.batch_idx_in_epoch += 1
-            trainer.check_batch_step_fn()
-            trainer.on_train_batch_end()
             trainer.step_evaluate()
         trainer.batch_idx_in_epoch = 0
@@ -47,7 +47,7 @@ class JittorDataLoader:
     A DataLoader for the jittor framework; it provides auto_collate functionality and supports datasets that implement __getitem__ and __len__.
     """
-    def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False,
+    def __init__(self, dataset, batch_size: int = 16, shuffle: bool = True,
                  drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
                  stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
                  collate_fn: Union[None, str, Callable] = "auto") -> None:
@@ -47,7 +47,7 @@ class PaddleDataLoader(DataLoader):
     def __init__(self, dataset, feed_list=None, places=None,
                  return_list: bool = True, batch_sampler=None,
-                 batch_size: int = 1, shuffle: bool = False,
+                 batch_size: int = 1, shuffle: bool = True,
                  drop_last: bool = False, collate_fn: Union[str, Callable, None] = 'auto',
                  num_workers: int = 0, use_buffer_reader: bool = True,
                  use_shared_memory: bool = True, timeout: int = 0,
@@ -14,7 +14,7 @@ from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS
 from ..log import logger
-def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False,
+def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = True, drop_last: bool = False,
                        collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0,
                        seed: int = 0, backend: str = 'auto'):
     """
@@ -179,7 +179,7 @@ class TorchDataLoader(DataLoader):
 def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]],
                              batch_size: int = 1,
-                             shuffle: bool = False,
+                             shuffle: bool = True,
                              sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                              batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
                              num_workers: int = 0, collate_fn: Union[str, Callable, None] = 'auto',
@@ -236,8 +236,8 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping
                                                   persistent_workers=persistent_workers,
                                                   )
             else:
-                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
-                                                  shuffle=shuffle, sampler=non_train_sampler,
+                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else batch_size,
+                                                  shuffle=shuffle, sampler=non_train_sampler if non_train_sampler else sampler,
                                                   batch_sampler=batch_sampler,
                                                   num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                                   drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
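With this change, loaders for the non-'train' splits fall back to the shared `batch_size` and `sampler` when `non_train_batch_size` / `non_train_sampler` are left unset, instead of receiving None. A hedged usage sketch (the dataset contents and the root-level import are assumptions; the keyword names follow the body shown above):

    from fastNLP import DataSet, prepare_torch_dataloader

    bundle = {
        'train': DataSet({'x': list(range(100)), 'y': list(range(100))}),
        'dev':   DataSet({'x': list(range(20)),  'y': list(range(20))}),
    }
    # 'dev' now inherits batch_size=32 and the shared sampler because no non_train_* values are given.
    dls = prepare_torch_dataloader(bundle, batch_size=32, shuffle=True)
    print(len(dls['train']), len(dls['dev']))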
@@ -1,9 +1,10 @@
+from typing import Callable
 __all__ = [
     "indice_collate_wrapper"
 ]
-def indice_collate_wrapper(func):
+def indice_collate_wrapper(func:Callable):
     """
     Wraps collate_fn with an extra layer that separates the tuple data fetched from the dataset and packs the idx values into indices.
@@ -1,11 +1,13 @@
 import os
 from typing import Dict, Union, Callable, Tuple, Optional
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 if _NEED_IMPORT_TORCH:
     import torch
     from torch.nn import DataParallel
     from torch.nn.parallel import DistributedDataParallel
     from torch.utils.data import RandomSampler as TorchRandomSampler
+    from torch.utils.data import SequentialSampler as TorchSequentialSampler
 __all__ = [
     'TorchSingleDriver'
@@ -15,7 +17,8 @@ from .torch_driver import TorchDriver
 from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
 from fastNLP.core.utils import auto_param_call
 from fastNLP.core.utils.utils import _get_fun_msg
-from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, ReproduceBatchSampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \
+    ReproduceBatchSampler
 from fastNLP.core.samplers import RandomSampler
 from fastNLP.core.log import logger
@@ -24,6 +27,7 @@ class TorchSingleDriver(TorchDriver):
     r"""
     Driver for running on CPU or a single GPU;
     """
+
     def __init__(self, model, device: "torch.device", fp16: bool = False, **kwargs):
         if isinstance(model, DistributedDataParallel):
             raise ValueError("`DistributedDataParallel` is not supported in `TorchSingleDriver`")
@@ -88,7 +92,8 @@ class TorchSingleDriver(TorchDriver):
         else:
             raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
-    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None,
+    def set_dist_repro_dataloader(self, dataloader,
+                                  dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None,
                                   reproducible: bool = False):
         # If dist is a ReproducibleBatchSampler or ReproducibleIterator, this call comes from driver.load when resuming training from a checkpoint;
@@ -108,17 +113,24 @@ class TorchSingleDriver(TorchDriver):
         if reproducible:
             if isinstance(args.sampler, TorchRandomSampler):
-                # If it was already random, just replace it.
-                sampler = RandomSampler(args.sampler.data_source)
-                logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
+                if getattr(args.sampler, '_num_samples', None) is None \
+                        and getattr(args.sampler, 'replacements', False) is False \
+                        and getattr(args.sampler, 'generator', None) is None:
+                    # If it was already random and not customized, just replace it.
+                    sampler = RandomSampler(args.sampler.data_source, shuffle=True)
+                    logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
+                    return replace_sampler(dataloader, sampler)
+            elif isinstance(args.sampler, TorchSequentialSampler):
+                # Needs to be replaced with a non-shuffling sampler.
+                sampler = RandomSampler(args.sampler.data_source, shuffle=False)
+                logger.debug("Replace torch SequentialSampler into fastNLP RandomSampler.")
                 return replace_sampler(dataloader, sampler)
-            else:
-                batch_sampler = ReproduceBatchSampler(
-                    batch_sampler=args.batch_sampler,
-                    batch_size=args.batch_size,
-                    drop_last=args.drop_last
-                )
-                return replace_batch_sampler(dataloader, batch_sampler)
+            batch_sampler = ReproduceBatchSampler(
+                batch_sampler=args.batch_sampler,
+                batch_size=args.batch_size,
+                drop_last=args.drop_last
+            )
+            return replace_batch_sampler(dataloader, batch_sampler)
         else:
             return dataloader
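The point of these replacements is that fastNLP's RandomSampler is a ReproducibleSampler: unlike torch's built-in RandomSampler/SequentialSampler, its iteration state can be saved and restored, which is what makes mid-epoch checkpoint resumption possible. A minimal sketch, assuming the sampler exposes the ReproducibleSampler state_dict/load_state_dict interface:

    from fastNLP.core.samplers import RandomSampler

    data = list(range(10))
    sampler = RandomSampler(data, shuffle=False)   # order-preserving stand-in for SequentialSampler
    state = sampler.state_dict()                   # can be stored with a training checkpoint ...

    restored = RandomSampler(data, shuffle=False)
    restored.load_state_dict(state)                # ... and reloaded later so iteration can resume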
@@ -138,9 +150,3 @@ class TorchSingleDriver(TorchDriver):
     def is_distributed(self):
         return False
@@ -1,11 +1,12 @@
 __all__ = [
     "Metric",
     "Accuracy",
+    "TransformersAccuracy",
     'SpanFPreRecMetric',
     'ClassifyFPreRecMetric',
 ]
 from .metric import Metric
-from .accuracy import Accuracy
+from .accuracy import Accuracy, TransformersAccuracy
 from .span_f1_pre_rec_metric import SpanFPreRecMetric
 from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric
@@ -1,5 +1,6 @@
 __all__ = [
-    'Accuracy'
+    'Accuracy',
+    "TransformersAccuracy"
 ]
 from typing import Union
@@ -17,9 +18,9 @@ class Accuracy(Metric):
         """
         Metric that computes accuracy.
-        :param str backend: four backends are currently supported: ['auto', 'torch', 'paddle', 'jittor']. 'auto' means the concrete backend is decided from the arguments actually passed to Metric.update()
+        :param backend: four backends are currently supported: ['auto', 'torch', 'paddle', 'jittor']. 'auto' means the concrete backend is decided from the arguments actually passed to Metric.update()
             when it is called; in most cases 'auto' can be used directly.
-        :param bool aggregate_when_get_metric: whether to automatically aggregate the same element across processes before producing the metric;
+        :param aggregate_when_get_metric: whether to automatically aggregate the same element across processes before producing the metric;
             this has no effect when the backend does not support distributed execution. If None, it is set automatically in the Evaluator according to whether the sampler is distributed.
         """
         super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
@@ -39,11 +40,11 @@ class Accuracy(Metric):
         r"""
         The update function accumulates the evaluation metric over one batch of predictions.
-        :param torch.Tensor pred: the prediction tensor; its shape can be torch.Size([B,]), torch.Size([B, n_classes]),
+        :param pred: the prediction tensor; its shape can be torch.Size([B,]), torch.Size([B, n_classes]),
             torch.Size([B, max_len]), or torch.Size([B, max_len, n_classes])
-        :param torch.Tensor target: the ground-truth tensor; its shape can be: torch.Size([B,]),
+        :param target: the ground-truth tensor; its shape can be: torch.Size([B,]),
             torch.Size([B,]), torch.Size([B, max_len]), or torch.Size([B, max_len])
-        :param torch.Tensor seq_len: the sequence-length marker; its shape can be None, None, torch.Size([B]), or torch.Size([B]).
+        :param seq_len: the sequence-length marker; its shape can be None, None, torch.Size([B]), or torch.Size([B]).
             If a mask is also passed in, seq_len is ignored.
         """
         # To stay compatible with different frameworks, all inputs are converted to numpy for the computation.
@@ -79,3 +80,20 @@ class Accuracy(Metric):
         else:
             self.total += np.prod(list(pred.shape)).item()
         self.correct += (target == pred).sum().item()
+
+
+class TransformersAccuracy(Accuracy):
+    """
+    Accuracy metric adapted to the relevant models in transformers.
+    """
+    def update(self, logits, labels, attention_mask=None):
+        r"""
+        The update function accumulates the evaluation metric over one batch of predictions.
+
+        :param logits: shape ``[B, n_classes]`` or ``[B, max_len, n_classes]``.
+        :param labels: shape ``[B, ]`` or ``[B, max_len]``
+        :param attention_mask: sequence-length mask.
+        """
+        seq_len = attention_mask.sum(dim=-1)
+        super().update(pred=logits, target=labels, seq_len=seq_len)
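A hedged usage sketch for the new metric, calling update directly with a torch backend; the tensor shapes follow the docstring above, and the 'acc' key in the printed result is an assumption about how Accuracy.get_metric reports its value:

    import torch
    from fastNLP.core.metrics import TransformersAccuracy

    metric = TransformersAccuracy()
    logits = torch.randn(2, 5, 4)                        # [B, max_len, n_classes]
    labels = torch.randint(0, 4, (2, 5))                 # [B, max_len]
    attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                                   [1, 1, 1, 1, 1]])     # padded positions are excluded via seq_len
    metric.update(logits=logits, labels=labels, attention_mask=attention_mask)
    print(metric.get_metric())                           # e.g. {'acc': ...}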
@@ -249,7 +249,7 @@ class DataBundle:
         return self
     def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True,
-                         ignore_miss_dataset=True, progress_desc: str = '', show_progress_bar: bool = True):
+                         ignore_miss_dataset=True, show_progress_bar: bool = True, progress_desc: str = ''):
         r"""
         Calls :meth:`~fastNLP.DataSet.apply_field_more` on every dataset held by this :class:`~fastNLP.io.DataBundle`
@@ -263,8 +263,8 @@ class DataBundle:
         :param num_proc: number of worker processes. Note that, due to how Python works, each extra process multiplies the memory usage.
         :param bool ignore_miss_dataset: if the field name does not exist in some dataset, that DataSet is skipped when True;
             when False, an error is raised
-        :param show_progress_bar: whether to show a tqdm progress bar
-        :param progress_desc: when show_progress_barm is True, the name of what tqdm is currently processing can be shown
+        :param show_progress_bar: whether to show a progress bar
+        :param progress_desc: when ``show_progress_bar`` is ``True``, the name shown for the ``progress``.
         :return Dict[str:Dict[str:Field]]: returns a dict of dicts; the first-level key is the dataset name and the second-level key is the field name
@@ -314,7 +314,7 @@ class PretrainedConfig:
         # TPU arguments
         if kwargs.pop("xla_device", None) is not None:
-            logger.warning(
+            logger.rank_zero_warning(
                 "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can "
                 "safely remove it from your `config.json` file."
             )
@@ -474,7 +474,7 @@ class PretrainedConfig:
         """
         config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
         if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
-            logger.warn(
+            logger.rank_zero_warning(
                 f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                 f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
             )
@@ -122,7 +122,7 @@ def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_leng
     stopping_max_length = stopping_criteria.max_length
     new_stopping_criteria = deepcopy(stopping_criteria)
     if stopping_max_length is not None and stopping_max_length != max_length:
-        logger.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
+        logger.rank_zero_warning("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
     elif stopping_max_length is None:
         new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
     return new_stopping_criteria
@@ -429,7 +429,7 @@ class GenerationMixin:
     def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int:
         if pad_token_id is None and eos_token_id is not None:
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+            logger.rank_zero_warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             pad_token_id = eos_token_id
         return pad_token_id
@@ -912,7 +912,7 @@ class GenerationMixin:
         # special case if pad_token_id is not defined
         if pad_token_id is None and eos_token_id is not None:
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+            logger.rank_zero_warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             pad_token_id = eos_token_id
         # Storing encoder_input_ids for logits_processor that could use them
@@ -352,7 +352,7 @@ class ModuleUtilsMixin:
         if token_inputs:
             return sum([token_input.numel() for token_input in token_inputs])
         else:
-            logger.warn(
+            logger.rank_zero_warning(
                 "Could not estimate the number of tokens of the input, floating-point operations will not be computed"
             )
             return 0
@@ -646,7 +646,7 @@ class PreTrainedModel(Module, ModuleUtilsMixin, GenerationMixin):
         # tie weights recursively
         tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights)
         if len(uninitialized_encoder_weights) > 0:
-            logger.warning(
+            logger.rank_zero_warning(
                 f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
             )
@@ -1486,7 +1486,7 @@ class PreTrainedModel(Module, ModuleUtilsMixin, GenerationMixin):
             raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
         if len(unexpected_keys) > 0:
-            logger.warning(
+            logger.rank_zero_warning(
                 f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                 f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                 f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
@@ -171,7 +171,7 @@ class BartConfig(PretrainedConfig):
         # ensure backward compatibility for BART CNN models
         if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
             self.forced_bos_token_id = self.bos_token_id
-            logger.warn(
+            logger.rank_zero_warning(
                 f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions."
                 "The config can simply be saved and uploaded again to be fixed."
             )
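The hunks in the vendored transformers modules above replace logger.warn / logger.warning with logger.rank_zero_warning, so that in a multi-process run (e.g. DDP) each deprecation or configuration warning is emitted once by the rank-0 process rather than once per process. A minimal hedged sketch of the behavioural difference, assuming these modules use fastNLP's logger:

    from fastNLP.core.log import logger

    logger.warning("emitted by every process in a distributed run")
    logger.rank_zero_warning("emitted only by the rank-0 process, so it shows up once")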