From 36174f727d482f829e3965adb83bc5e9119c5597 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Wed, 11 May 2022 19:18:10 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0TransformersAccuracy?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/__init__.py | 1 + fastNLP/core/callbacks/callback_manager.py | 2 +- .../controllers/loops/evaluate_batch_loop.py | 14 ++++--- .../controllers/loops/train_batch_loop.py | 27 ++++++------ .../core/dataloaders/jittor_dataloader/fdl.py | 2 +- .../core/dataloaders/paddle_dataloader/fdl.py | 2 +- .../core/dataloaders/prepare_dataloader.py | 2 +- .../core/dataloaders/torch_dataloader/fdl.py | 6 +-- fastNLP/core/dataloaders/utils.py | 3 +- .../drivers/torch_driver/single_device.py | 42 +++++++++++-------- fastNLP/core/metrics/__init__.py | 3 +- fastNLP/core/metrics/accuracy.py | 30 ++++++++++--- fastNLP/io/data_bundle.py | 6 +-- .../transformers/torch/configuration_utils.py | 4 +- .../torch/generation_stopping_criteria.py | 2 +- .../transformers/torch/generation_utils.py | 4 +- fastNLP/transformers/torch/modeling_utils.py | 6 +-- .../torch/models/bart/configuration_bart.py | 2 +- 18 files changed, 94 insertions(+), 64 deletions(-) diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index fc47b470..343313a6 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -69,6 +69,7 @@ __all__ = [ # metrics "Metric", "Accuracy", + "TransformersAccuracy", 'SpanFPreRecMetric', 'ClassifyFPreRecMetric', diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index 27770115..765a0346 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -25,7 +25,7 @@ def _transfer(func): for callback_fn in manager.callback_fns[func.__name__]: try: callback_fn(*arg, **kwargs) - except EarlyStopException as e: + except (EarlyStopException, KeyboardInterrupt) as e: raise e except BaseException as e: logger.error(f"The following callback_fn raise exception:{_get_fun_msg(callback_fn)}.") diff --git a/fastNLP/core/controllers/loops/evaluate_batch_loop.py b/fastNLP/core/controllers/loops/evaluate_batch_loop.py index 80c234cd..c81379a1 100644 --- a/fastNLP/core/controllers/loops/evaluate_batch_loop.py +++ b/fastNLP/core/controllers/loops/evaluate_batch_loop.py @@ -27,19 +27,21 @@ class EvaluateBatchLoop(Loop): while True: try: batch = next(iterator) - batch = match_and_substitute_params(evaluator.input_mapping, batch) - batch = evaluator.move_data_to_device(batch) except StopIteration: break + try: + batch = match_and_substitute_params(evaluator.input_mapping, batch) + batch = evaluator.move_data_to_device(batch) + + self.batch_step_fn(evaluator, batch) + batch_idx += 1 + evaluator.update_progress_bar(batch_idx, evaluator.cur_dataloader_name) + except BaseException as e: if callable(getattr(dataloader, 'get_batch_indices', None)): indices = dataloader.get_batch_indices() logger.error(f"Exception happens when evaluating on samples: {indices}") raise e - - self.batch_step_fn(evaluator, batch) - batch_idx += 1 - evaluator.update_progress_bar(batch_idx, evaluator.cur_dataloader_name) # 获取metric结果。返回的dict内容示例为{'metric_name1': metric_results, 'metric_name2': metric_results, ...} results = evaluator.get_metric() return results diff --git a/fastNLP/core/controllers/loops/train_batch_loop.py b/fastNLP/core/controllers/loops/train_batch_loop.py index 989fb2ae..7bb9b653 100644 --- a/fastNLP/core/controllers/loops/train_batch_loop.py +++ b/fastNLP/core/controllers/loops/train_batch_loop.py @@ -19,30 +19,31 @@ class TrainBatchLoop(Loop): get_batch_indices = dataloader.get_batch_indices if callable(getattr(dataloader, 'get_batch_indices', None))\ else lambda *args, **kwargs: None dataloader = iter(dataloader) - indices = None while trainer.batch_idx_in_epoch<=trainer.num_batches_per_epoch: try: trainer.on_fetch_data_begin() batch = next(dataloader) indices = get_batch_indices() + except StopIteration: + break + + try: trainer.on_fetch_data_end() batch = match_and_substitute_params(trainer.input_mapping, batch) batch = trainer.move_data_to_device(batch) - except StopIteration: - break + + trainer.on_train_batch_begin(batch, indices) + with trainer.get_no_sync_context(): # 在多卡的时候可能需要关闭 sync + self.batch_step_fn(trainer, batch) + trainer.global_forward_batches += 1 + trainer.batch_idx_in_epoch += 1 + + trainer.check_batch_step_fn() + trainer.on_train_batch_end() except BaseException as e: - if indices and not isinstance(e, EarlyStopException): + if indices is not None and not isinstance(e, (EarlyStopException, KeyboardInterrupt)): logger.error(f"Exception happens when running on samples: {indices}") raise e - - trainer.on_train_batch_begin(batch, indices) - with trainer.get_no_sync_context(): # 在多卡的时候可能需要关闭 sync - self.batch_step_fn(trainer, batch) - trainer.global_forward_batches += 1 - trainer.batch_idx_in_epoch += 1 - - trainer.check_batch_step_fn() - trainer.on_train_batch_end() trainer.step_evaluate() trainer.batch_idx_in_epoch = 0 diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py index 8ecd2d87..b76fd4c1 100644 --- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py @@ -47,7 +47,7 @@ class JittorDataLoader: 提供给使用jittor框架的DataLoader函数,提供了auto_collate的功能, 支持实现了__getitem__和__len__的dataset """ - def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False, + def __init__(self, dataset, batch_size: int = 16, shuffle: bool = True, drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, collate_fn: Union[None, str, Callable] = "auto") -> None: diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py index 393324d4..4c2f2300 100644 --- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py @@ -47,7 +47,7 @@ class PaddleDataLoader(DataLoader): def __init__(self, dataset, feed_list=None, places=None, return_list: bool = True, batch_sampler=None, - batch_size: int = 1, shuffle: bool = False, + batch_size: int = 1, shuffle: bool = True, drop_last: bool = False, collate_fn: Union[str, Callable, None] = 'auto', num_workers: int = 0, use_buffer_reader: bool = True, use_shared_memory: bool = True, timeout: int = 0, diff --git a/fastNLP/core/dataloaders/prepare_dataloader.py b/fastNLP/core/dataloaders/prepare_dataloader.py index 8a7e3d1e..33764c6f 100644 --- a/fastNLP/core/dataloaders/prepare_dataloader.py +++ b/fastNLP/core/dataloaders/prepare_dataloader.py @@ -14,7 +14,7 @@ from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS from ..log import logger -def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False, +def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = True, drop_last: bool = False, collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0, seed: int = 0, backend: str = 'auto'): """ diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 456af44f..726abaae 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -179,7 +179,7 @@ class TorchDataLoader(DataLoader): def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]], batch_size: int = 1, - shuffle: bool = False, + shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, num_workers: int = 0, collate_fn: Union[str, Callable, None] = 'auto', @@ -236,8 +236,8 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping persistent_workers=persistent_workers, ) else: - dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, - shuffle=shuffle, sampler=non_train_sampler, + dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else batch_size, + shuffle=shuffle, sampler=non_train_sampler if non_train_sampler else sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, diff --git a/fastNLP/core/dataloaders/utils.py b/fastNLP/core/dataloaders/utils.py index 39ce5983..495fb6d3 100644 --- a/fastNLP/core/dataloaders/utils.py +++ b/fastNLP/core/dataloaders/utils.py @@ -1,9 +1,10 @@ +from typing import Callable __all__ = [ "indice_collate_wrapper" ] -def indice_collate_wrapper(func): +def indice_collate_wrapper(func:Callable): """ 其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。 diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py index 6c125a73..8aa9a2d5 100644 --- a/fastNLP/core/drivers/torch_driver/single_device.py +++ b/fastNLP/core/drivers/torch_driver/single_device.py @@ -1,11 +1,13 @@ import os from typing import Dict, Union, Callable, Tuple, Optional from fastNLP.envs.imports import _NEED_IMPORT_TORCH + if _NEED_IMPORT_TORCH: import torch from torch.nn import DataParallel from torch.nn.parallel import DistributedDataParallel from torch.utils.data import RandomSampler as TorchRandomSampler + from torch.utils.data import SequentialSampler as TorchSequentialSampler __all__ = [ 'TorchSingleDriver' @@ -15,7 +17,8 @@ from .torch_driver import TorchDriver from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler from fastNLP.core.utils import auto_param_call from fastNLP.core.utils.utils import _get_fun_msg -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, ReproduceBatchSampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \ + ReproduceBatchSampler from fastNLP.core.samplers import RandomSampler from fastNLP.core.log import logger @@ -24,6 +27,7 @@ class TorchSingleDriver(TorchDriver): r""" 用于 cpu 和 单卡 gpu 运算; """ + def __init__(self, model, device: "torch.device", fp16: bool = False, **kwargs): if isinstance(model, DistributedDataParallel): raise ValueError("`DistributedDataParallel` is not supported in `TorchSingleDriver`") @@ -88,7 +92,8 @@ class TorchSingleDriver(TorchDriver): else: raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") - def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, + def set_dist_repro_dataloader(self, dataloader, + dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None, reproducible: bool = False): # 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; @@ -108,17 +113,24 @@ class TorchSingleDriver(TorchDriver): if reproducible: if isinstance(args.sampler, TorchRandomSampler): - # 如果本来就是随机的,直接替换掉吧。 - sampler = RandomSampler(args.sampler.data_source) - logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") + if getattr(args.sampler, '_num_samples', None) is None \ + and getattr(args.sampler, 'replacements', False) is False \ + and getattr(args.sampler, 'generator', None) is None: + # 如果本来就是随机的,并且没有定制,直接替换掉吧。 + sampler = RandomSampler(args.sampler.data_source, shuffle=True) + logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") + return replace_sampler(dataloader, sampler) + elif isinstance(args.sampler, TorchSequentialSampler): + # 需要替换为不要 shuffle 的。 + sampler = RandomSampler(args.sampler.data_source, shuffle=False) + logger.debug("Replace torch SequentialSampler into fastNLP RandomSampler.") return replace_sampler(dataloader, sampler) - else: - batch_sampler = ReproduceBatchSampler( - batch_sampler=args.batch_sampler, - batch_size=args.batch_size, - drop_last=args.drop_last - ) - return replace_batch_sampler(dataloader, batch_sampler) + batch_sampler = ReproduceBatchSampler( + batch_sampler=args.batch_sampler, + batch_size=args.batch_size, + drop_last=args.drop_last + ) + return replace_batch_sampler(dataloader, batch_sampler) else: return dataloader @@ -138,9 +150,3 @@ class TorchSingleDriver(TorchDriver): def is_distributed(self): return False - - - - - - diff --git a/fastNLP/core/metrics/__init__.py b/fastNLP/core/metrics/__init__.py index f7d60606..b7f572e8 100644 --- a/fastNLP/core/metrics/__init__.py +++ b/fastNLP/core/metrics/__init__.py @@ -1,11 +1,12 @@ __all__ = [ "Metric", "Accuracy", + "TransformersAccuracy", 'SpanFPreRecMetric', 'ClassifyFPreRecMetric', ] from .metric import Metric -from .accuracy import Accuracy +from .accuracy import Accuracy, TransformersAccuracy from .span_f1_pre_rec_metric import SpanFPreRecMetric from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric diff --git a/fastNLP/core/metrics/accuracy.py b/fastNLP/core/metrics/accuracy.py index 0869d8c8..59990f95 100644 --- a/fastNLP/core/metrics/accuracy.py +++ b/fastNLP/core/metrics/accuracy.py @@ -1,5 +1,6 @@ __all__ = [ - 'Accuracy' + 'Accuracy', + "TransformersAccuracy" ] from typing import Union @@ -17,9 +18,9 @@ class Accuracy(Metric): """ 计算 准确率 的 metric 。 - :param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() + :param backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() 函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 - :param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到 metric, + :param aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到 metric, 当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 """ super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) @@ -39,11 +40,11 @@ class Accuracy(Metric): r""" update 函数将针对一个批次的预测结果做评价指标的累计 - :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), + :param pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) - :param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), + :param target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len]) - :param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, None, torch.Size([B]), 或者torch.Size([B]). + :param seq_len: 序列长度标记, 标记的形状可以是None, None, torch.Size([B]), 或者torch.Size([B]). 如果mask也被传进来的话seq_len会被忽略. """ # 为了兼容不同框架,我们将输入变量全部转为numpy类型来进行计算。 @@ -79,3 +80,20 @@ class Accuracy(Metric): else: self.total += np.prod(list(pred.shape)).item() self.correct += (target == pred).sum().item() + + +class TransformersAccuracy(Accuracy): + """ + 适配 transformers 中相关模型的 Accuracy metric 。 + + """ + def update(self, logits, labels, attention_mask=None): + r""" + update 函数将针对一个批次的预测结果做评价指标的累计 + + :param logits: 形状为 ``[B, n_classes]`` 或 ``[B, max_len, n_classes]`` 。 + :param labels: 形状为 ``[B, ]`` 或 ``[B, max_len]`` + :param attention_mask: 序列长度标记。 + """ + seq_len = attention_mask.sum(dim=-1) + super().update(pred=logits, target=labels, seq_len=seq_len) \ No newline at end of file diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index a3c15a28..df194df2 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -249,7 +249,7 @@ class DataBundle: return self def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True, - ignore_miss_dataset=True, progress_desc: str = '', show_progress_bar: bool = True): + ignore_miss_dataset=True, show_progress_bar: bool = True, progress_desc: str = ''): r""" 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法 @@ -263,8 +263,8 @@ class DataBundle: :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; 如果为False,则报错 - :param show_progress_bar: 是否显示tqdm进度条 - :param progress_desc: 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 + :param show_progress_bar: 是否显示进度条 + :param progress_desc: 当 ``show_progress_bar`` 为 ``True`` 时,可以显示 ``progress`` 的名称。 :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 diff --git a/fastNLP/transformers/torch/configuration_utils.py b/fastNLP/transformers/torch/configuration_utils.py index 9c17f336..fb494d9f 100644 --- a/fastNLP/transformers/torch/configuration_utils.py +++ b/fastNLP/transformers/torch/configuration_utils.py @@ -314,7 +314,7 @@ class PretrainedConfig: # TPU arguments if kwargs.pop("xla_device", None) is not None: - logger.warning( + logger.rank_zero_warning( "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can " "safely remove it from your `config.json` file." ) @@ -474,7 +474,7 @@ class PretrainedConfig: """ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: - logger.warn( + logger.rank_zero_warning( f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." ) diff --git a/fastNLP/transformers/torch/generation_stopping_criteria.py b/fastNLP/transformers/torch/generation_stopping_criteria.py index 179bf7c1..da2bcf9b 100644 --- a/fastNLP/transformers/torch/generation_stopping_criteria.py +++ b/fastNLP/transformers/torch/generation_stopping_criteria.py @@ -122,7 +122,7 @@ def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_leng stopping_max_length = stopping_criteria.max_length new_stopping_criteria = deepcopy(stopping_criteria) if stopping_max_length is not None and stopping_max_length != max_length: - logger.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning) + logger.rank_zero_warning("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning) elif stopping_max_length is None: new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length)) return new_stopping_criteria diff --git a/fastNLP/transformers/torch/generation_utils.py b/fastNLP/transformers/torch/generation_utils.py index cfc2108c..0e6fe5c7 100644 --- a/fastNLP/transformers/torch/generation_utils.py +++ b/fastNLP/transformers/torch/generation_utils.py @@ -429,7 +429,7 @@ class GenerationMixin: def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int: if pad_token_id is None and eos_token_id is not None: - logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") + logger.rank_zero_warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") pad_token_id = eos_token_id return pad_token_id @@ -912,7 +912,7 @@ class GenerationMixin: # special case if pad_token_id is not defined if pad_token_id is None and eos_token_id is not None: - logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") + logger.rank_zero_warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") pad_token_id = eos_token_id # Storing encoder_input_ids for logits_processor that could use them diff --git a/fastNLP/transformers/torch/modeling_utils.py b/fastNLP/transformers/torch/modeling_utils.py index d1d5c2f3..d19816a3 100644 --- a/fastNLP/transformers/torch/modeling_utils.py +++ b/fastNLP/transformers/torch/modeling_utils.py @@ -352,7 +352,7 @@ class ModuleUtilsMixin: if token_inputs: return sum([token_input.numel() for token_input in token_inputs]) else: - logger.warn( + logger.rank_zero_warning( "Could not estimate the number of tokens of the input, floating-point operations will not be computed" ) return 0 @@ -646,7 +646,7 @@ class PreTrainedModel(Module, ModuleUtilsMixin, GenerationMixin): # tie weights recursively tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights) if len(uninitialized_encoder_weights) > 0: - logger.warning( + logger.rank_zero_warning( f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}" ) @@ -1486,7 +1486,7 @@ class PreTrainedModel(Module, ModuleUtilsMixin, GenerationMixin): raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") if len(unexpected_keys) > 0: - logger.warning( + logger.rank_zero_warning( f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " f"initializing {model.__class__.__name__}: {unexpected_keys}\n" f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " diff --git a/fastNLP/transformers/torch/models/bart/configuration_bart.py b/fastNLP/transformers/torch/models/bart/configuration_bart.py index 3b52bc81..9465326b 100644 --- a/fastNLP/transformers/torch/models/bart/configuration_bart.py +++ b/fastNLP/transformers/torch/models/bart/configuration_bart.py @@ -171,7 +171,7 @@ class BartConfig(PretrainedConfig): # ensure backward compatibility for BART CNN models if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False): self.forced_bos_token_id = self.bos_token_id - logger.warn( + logger.rank_zero_warning( f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions." "The config can simply be saved and uploaded again to be fixed." )