
Merge remote-tracking branch 'origin/dev0.8.0' into deepspeed

tags/v1.0.0alpha
x54-729 2 years ago
commit 0e865d292d
28 changed files with 82 additions and 76 deletions
  1. fastNLP/core/callbacks/fitlog_callback.py (+1 -1)
  2. fastNLP/core/callbacks/load_best_model_callback.py (+10 -8)
  3. fastNLP/core/callbacks/progress_callback.py (+5 -7)
  4. fastNLP/core/dataloaders/jittor_dataloader/fdl.py (+8 -7)
  5. fastNLP/core/dataloaders/paddle_dataloader/fdl.py (+9 -6)
  6. fastNLP/core/dataloaders/prepare_dataloader.py (+3 -2)
  7. fastNLP/core/dataloaders/torch_dataloader/fdl.py (+8 -7)
  8. fastNLP/core/metrics/accuracy.py (+1 -1)
  9. fastNLP/core/metrics/classify_f1_pre_rec_metric.py (+1 -1)
  10. fastNLP/core/metrics/span_f1_pre_rec_metric.py (+1 -1)
  11. fastNLP/core/utils/utils.py (+1 -1)
  12. fastNLP/embeddings/torch/static_embedding.py (+2 -2)
  13. fastNLP/io/embed_loader.py (+2 -2)
  14. fastNLP/io/loader/classification.py (+1 -1)
  15. fastNLP/io/loader/matching.py (+3 -3)
  16. fastNLP/io/pipe/matching.py (+2 -2)
  17. fastNLP/io/pipe/utils.py (+1 -1)
  18. fastNLP/modules/mix_modules/utils.py (+1 -1)
  19. fastNLP/transformers/torch/configuration_utils.py (+1 -1)
  20. fastNLP/transformers/torch/generation_beam_search.py (+1 -1)
  21. fastNLP/transformers/torch/generation_utils.py (+7 -7)
  22. fastNLP/transformers/torch/models/auto/auto_factory.py (+1 -1)
  23. fastNLP/transformers/torch/models/auto/configuration_auto.py (+1 -1)
  24. fastNLP/transformers/torch/models/auto/modeling_auto.py (+2 -2)
  25. fastNLP/transformers/torch/models/bart/modeling_bart.py (+1 -1)
  26. fastNLP/transformers/torch/models/bert/modeling_bert.py (+1 -1)
  27. fastNLP/transformers/torch/models/cpt/modeling_cpt.py (+1 -1)
  28. fastNLP/transformers/torch/tokenization_utils_base.py (+6 -6)

fastNLP/core/callbacks/fitlog_callback.py (+1 -1)

@@ -44,7 +44,7 @@ class FitlogCallback(HasMonitorCallback):
         if get_global_rank() != 0:  # fitlog needs to be turned off on every rank except global rank 0
             fitlog.debug()
         super().on_after_trainer_initialized(trainer, driver)
-        fitlog.add_other('launch_time', os.environ['FASTNLP_LAUNCH_TIME'])
+        fitlog.add_other(name='launch_time', value=os.environ['FASTNLP_LAUNCH_TIME'])

     def on_sanity_check_end(self, trainer, sanity_check_res):
         super(FitlogCallback, self).on_sanity_check_end(trainer, sanity_check_res)
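A side note on the one-line change above: it only switches fitlog.add_other from positional to keyword arguments. A plausible motivation (an assumption, not stated in the commit) is that fitlog's add_other takes its parameters in a different positional order in some releases, in which case the old call could have recorded the string 'launch_time' as the value. A minimal stand-in showing why keyword arguments are order-proof:

# Stand-in for fitlog.add_other; the real signature is an assumption here.
def add_other(value=None, name=None):
    return {'name': name, 'value': value}

print(add_other('launch_time', '22-05-01'))             # positional: fields end up swapped
print(add_other(name='launch_time', value='22-05-01'))  # keyword: unambiguous either way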


fastNLP/core/callbacks/load_best_model_callback.py (+10 -8)

@@ -105,14 +105,16 @@ class LoadBestModelCallback(HasMonitorCallback):

     def on_train_end(self, trainer):
         if abs(self.monitor_value) != float('inf'):  # if it is inf, the monitor never ran
-            if self.real_save_folder:
-                logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value}...")
-                trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
-                                   model_load_fn=self.model_load_fn)
-            else:
-                logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value}...")
-                self.buffer.seek(0)
-                trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
+            # if running distributed and an exception occurred, do not load, to avoid barrier problems
+            if not (trainer.driver.is_distributed() and self.encounter_exception):
+                if self.real_save_folder:
+                    logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value}...")
+                    trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
+                                       model_load_fn=self.model_load_fn)
+                else:
+                    logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value}...")
+                    self.buffer.seek(0)
+                    trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
         if self.delete_after_after:
             if not self.encounter_exception:  # prevent deadlocks
                 trainer.driver.barrier()
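The new guard matters in distributed runs: loading a model is a collective operation that synchronizes all ranks, and a rank that already hit an exception may never reach that synchronization point, leaving the surviving ranks stuck at the barrier. A minimal self-contained sketch of the control flow (names are illustrative, not fastNLP's real API):

# Sketch of the guard added above, with hypothetical stand-in objects.
class _Driver:
    def __init__(self, distributed: bool):
        self._distributed = distributed

    def is_distributed(self) -> bool:
        return self._distributed

def maybe_load_best(driver: _Driver, encountered_exception: bool) -> str:
    if driver.is_distributed() and encountered_exception:
        return "skipped"  # never start a collective op some ranks cannot join
    return "loaded"       # all ranks participate, safe to synchronize

print(maybe_load_best(_Driver(distributed=True), encountered_exception=True))   # skipped
print(maybe_load_best(_Driver(distributed=False), encountered_exception=True))  # loaded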


fastNLP/core/callbacks/progress_callback.py (+5 -7)

@@ -22,9 +22,10 @@ class ProgressCallback(HasMonitorCallback):
         self.best_monitor_step = -1
         self.best_results = None

-    def record_better_monitor(self, trainer):
+    def record_better_monitor(self, trainer, results):
         self.best_monitor_step = trainer.global_forward_batches
         self.best_monitor_epoch = trainer.cur_epoch_idx
+        self.best_results = self.itemize_results(results)

     def on_train_end(self, trainer):
         if self.best_monitor_epoch != -1:
@@ -138,7 +139,7 @@ class RichCallback(ProgressCallback):
             characters = '-'
         if self.monitor is not None:
             if self.is_better_results(results, keep_if_better=True):
-                self.record_better_monitor(trainer)
+                self.record_better_monitor(trainer, results)
                 if abs(self.monitor_value) != float('inf'):
                     rule_style = 'spring_green3'
                     text_style = '[bold]'
@@ -154,7 +155,6 @@ class RichCallback(ProgressCallback):
             self.progress_bar.console.print_json(results)
         else:
             self.progress_bar.print(results)
-        self.best_results = results

     def clear_tasks(self):
         for key, taskid in self.task2id.items():
@@ -222,7 +222,7 @@ class RawTextCallback(ProgressCallback):
         text = ''
         if self.monitor is not None:
             if self.is_better_results(results, keep_if_better=True):
-                self.record_better_monitor(trainer)
+                self.record_better_monitor(trainer, results)
                 if abs(self.monitor_value) != float('inf'):
                     text = '+'*self.num_signs + base_text + '+'*self.num_signs
         if len(text) == 0:
@@ -234,7 +234,6 @@ class RawTextCallback(ProgressCallback):
         if self.format_json:
             results = json.dumps(results)
         logger.info(results)
-        self.best_results = results

     @property
     def name(self):  # name of the progress bar
@@ -311,7 +310,7 @@ class TqdmCallback(ProgressCallback):
         text = ''
         if self.monitor is not None:
             if self.is_better_results(results, keep_if_better=True):
-                self.record_better_monitor(trainer)
+                self.record_better_monitor(trainer, results)
                 if abs(self.monitor_value) != float('inf'):
                     text = '+'*self.num_signs + base_text + '+'*self.num_signs
         if len(text) == 0:
@@ -323,7 +322,6 @@ class TqdmCallback(ProgressCallback):
         if self.format_json:
             results = json.dumps(results)
         logger.info(results)
-        self.best_results = results

     def clear_tasks(self):
         for key, taskid in self.task2id.items():
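The motivation for threading results into record_better_monitor is visible in the removed lines: previously self.best_results = results ran after the results may already have been replaced by json.dumps(results), so the stored "best" results could end up as a JSON string of the latest evaluation rather than the evaluation that produced the best monitor value. Snapshotting at record time fixes both problems. A hedged sketch of the idea (the real itemize_results may differ):

# What itemize_results presumably does: convert tensor/ndarray scalars to
# plain Python numbers so the snapshot is decoupled from framework objects.
def itemize_results(results: dict) -> dict:
    return {k: (v.item() if hasattr(v, 'item') else v) for k, v in results.items()}

class Recorder:
    def __init__(self):
        self.best_results = None

    def record_better_monitor(self, results: dict) -> None:
        # snapshot exactly when the better monitor value is observed,
        # not after the results dict has been printed or serialized
        self.best_results = itemize_results(results)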


fastNLP/core/dataloaders/jittor_dataloader/fdl.py (+8 -7)

@@ -200,7 +200,7 @@ class JittorDataLoader:
         return self.cur_batch_indices


-def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = False,
+def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = None,
                               drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
                               stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
                               collate_fn: Union[None, str, Callable] = "auto",
@@ -230,7 +230,8 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
     :param non_train_batch_size: if the passed ``ds_or_db`` is a :class:`Dict` or :class:`~fastNLP.io.DataBundle` object, this
         parameter sets the ``batch_size`` of every ``dataset`` whose name is not `train`. Defaults to ``16``.
     :param batch_size: batch size, defaults to ``16``; only effective when batch_sampler is None.
-    :param shuffle: whether to shuffle the dataset, defaults to ``False``.
+    :param shuffle: whether to shuffle the dataset, defaults to ``None``; if the passed ``ds_or_db`` allows a 'train' split to
+        be identified, its shuffle is set to True and all others to False.
     :param drop_last: when ``drop_last=True``, :class:`JittorDataLoader` drops the last batch if it is shorter than ``batch_size``;
         when ``drop_last=False``, that batch is returned. Defaults to ``False``.
     :param num_workers: when ``num_workers > 0``, :class:`JittorDataLoader` spawns num_workers subprocesses to process the data, which can speed
@@ -258,7 +259,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
         dl_bundle = {}
         for name, ds in ds_or_db.iter_datasets():
             if 'train' in name:
-                dl_bundle[name] = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle,
+                dl_bundle[name] = JittorDataLoader(ds, batch_size=batch_size, shuffle=True if shuffle is None else shuffle,
                                                    drop_last=drop_last, num_workers=num_workers,
                                                    buffer_size=buffer_size,
                                                    stop_grad=stop_grad, keep_numpy_array=keep_numpy_array,
@@ -267,7 +268,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
             else:
                 dl_bundle[name] = JittorDataLoader(ds,
                                                    batch_size=non_train_batch_size if non_train_batch_size else batch_size,
-                                                   shuffle=shuffle,
+                                                   shuffle=False if shuffle is None else shuffle,
                                                    drop_last=drop_last, num_workers=num_workers,
                                                    buffer_size=buffer_size,
                                                    stop_grad=stop_grad, keep_numpy_array=keep_numpy_array,
@@ -279,14 +280,14 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
         ds_dict = {}
         for name, ds in ds_or_db.items():
             if 'train' in name:
-                dl = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle,
+                dl = JittorDataLoader(ds, batch_size=batch_size, shuffle=True if shuffle is None else shuffle,
                                       drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size,
                                       stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless,
                                       collate_fn=collate_fn)
             else:
                 dl = JittorDataLoader(ds,
                                       batch_size=non_train_batch_size if non_train_batch_size else batch_size,
-                                      shuffle=shuffle,
+                                      shuffle=False if shuffle is None else shuffle,
                                       drop_last=drop_last, num_workers=num_workers,
                                       buffer_size=buffer_size,
                                       stop_grad=stop_grad, keep_numpy_array=keep_numpy_array,
@@ -296,7 +297,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
         return ds_dict

     elif isinstance(ds_or_db, HasLenGetitemType):
-        dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle,
+        dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=False if shuffle is None else shuffle,
                               drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size,
                               stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless,
                               collate_fn=collate_fn)
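All the dataloader changes in this commit implement the same tri-state default: shuffle=None now means "shuffle only the split identified as train". The resolution logic repeated inline above can be read as a small helper (a sketch, not code from the commit):

from typing import Optional

def resolve_shuffle(split_name: str, shuffle: Optional[bool]) -> bool:
    if shuffle is None:
        return 'train' in split_name   # train-like splits shuffle by default
    return shuffle                     # an explicit True/False wins everywhere

assert resolve_shuffle('train', None) is True
assert resolve_shuffle('dev', None) is False
assert resolve_shuffle('dev', True) is True

Note that for a bare dataset (the HasLenGetitemType branch) there is no split name, so shuffle=None falls back to False, matching the old default.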


fastNLP/core/dataloaders/paddle_dataloader/fdl.py (+9 -6)

@@ -293,7 +293,8 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
     :param batch_sampler: an instantiated object implementing __len__() and __iter__(); its __iter__() method returns a List each time, whose
         values are indices into the dataset. Defaults to None; when it is not None, the batch_size and shuffle parameters are both ignored.
     :param batch_size: batch size, defaults to ``16``; only effective when batch_sampler is None.
-    :param shuffle: whether to shuffle the data; if ``shuffle=True`` the dataset is shuffled, otherwise nothing is done.
+    :param shuffle: whether to shuffle the dataset, defaults to ``None``; if the passed ``ds_or_db`` allows a 'train' split to
+        be identified, its shuffle is set to True and all others to False.
     :param drop_last: when ``drop_last=True``, ``PaddleDataLoader`` drops the last batch if it is shorter than ``batch_size``;
         when ``drop_last=False``, that batch is returned. Defaults to ``False``.
     :param collate_fn: a Callable used to pack one batch of data drawn from the dataset; its value should be one of: ``[None, "auto", Callable]``.
@@ -326,7 +327,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
                 dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places,
                                                    return_list=return_list,
                                                    batch_sampler=batch_sampler, batch_size=batch_size,
-                                                   shuffle=shuffle,
+                                                   shuffle=True if shuffle is None else shuffle,
                                                    drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
                                                    use_shared_memory=use_shared_memory,
                                                    use_buffer_reader=use_buffer_reader,
@@ -337,7 +338,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
                                                    return_list=return_list,
                                                    batch_sampler=batch_sampler,
                                                    batch_size=non_train_batch_size if non_train_batch_size else batch_size,
-                                                   shuffle=shuffle,
+                                                   shuffle=False if shuffle is None else shuffle,
                                                    drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
                                                    use_shared_memory=use_shared_memory,
                                                    use_buffer_reader=use_buffer_reader,
@@ -350,7 +351,8 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
         for name, ds in ds_or_db.items():
             if 'train' in name:
                 dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list,
-                                      batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle,
+                                      batch_sampler=batch_sampler, batch_size=batch_size,
+                                      shuffle=False if shuffle is None else shuffle,
                                       drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
                                       use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
                                       timeout=timeout, worker_init_fn=worker_init_fn,
@@ -359,7 +361,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
                 dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list,
                                       batch_sampler=batch_sampler,
                                       batch_size=non_train_batch_size if non_train_batch_size else batch_size,
-                                      shuffle=shuffle,
+                                      shuffle=False if shuffle is None else shuffle,
                                       drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
                                       use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
                                       timeout=timeout, worker_init_fn=worker_init_fn,
@@ -369,7 +371,8 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,

     elif isinstance(ds_or_db, HasLenGetitemType):
         dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list,
-                              batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle,
+                              batch_sampler=batch_sampler, batch_size=batch_size,
+                              shuffle=False if shuffle is None else shuffle,
                               drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
                               use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
                               timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)


fastNLP/core/dataloaders/prepare_dataloader.py (+3 -2)

@@ -13,7 +13,7 @@ from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS
 from ..log import logger


-def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False,
+def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = None, drop_last: bool = False,
                        collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0,
                        backend: str = 'auto'):
     """
@@ -28,7 +28,8 @@ def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, dro
         * for dict-typed data or a :class:`~fastNLP.io.DataBundle`, returns data of type `Dict`.

     :param batch_size: batch size.
-    :param shuffle: whether to shuffle the dataset.
+    :param shuffle: whether to shuffle the dataset, defaults to ``None``; if the passed ``ds_or_db`` allows a 'train' split to
+        be identified, its shuffle is set to True and all others to False.
     :param drop_last: whether to drop the last batch when it contains fewer than batch_size samples.
     :param collate_fn: a function used to process one batch, typically padding and conversion to tensors. It accepts three kinds of values:




fastNLP/core/dataloaders/torch_dataloader/fdl.py (+8 -7)

@@ -218,7 +218,7 @@ class TorchDataLoader(DataLoader):

 def prepare_torch_dataloader(ds_or_db,
                              batch_size: int = 16,
-                             shuffle: bool = False,
+                             shuffle: bool = None,
                              sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                              batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
                              num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto',
@@ -252,7 +252,8 @@ def prepare_torch_dataloader(ds_or_db,

     :param batch_size: batch size, defaults to ``16``; only effective when batch_sampler is None.
     :param non_train_batch_size: batch size of the ``TorchDataLoader`` for non-'train' datasets, defaults to ``16``; only effective when batch_sampler is None.
-    :param shuffle: whether to shuffle the dataset, defaults to ``False``.
+    :param shuffle: whether to shuffle the dataset, defaults to ``None``; if the passed ``ds_or_db`` allows a 'train' split to
+        be identified, its shuffle is set to True and all others to False.
     :param sampler: an instantiated object implementing __len__() and __iter__(); its __iter__() method returns one dataset index at a time.
         Defaults to None; when it is not None, the shuffle parameter is ignored.
     :param non_train_sampler: the same kind of object for non-'train' datasets, implementing __len__() and __iter__(); its __iter__() method returns one dataset index at a time,
@@ -290,7 +291,7 @@ def prepare_torch_dataloader(ds_or_db,
         for name, ds in ds_or_db.iter_datasets():
             if 'train' in name:
                 dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size,
-                                                  shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
+                                                  shuffle=True if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler,
                                                   num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                                   drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
                                                   multiprocessing_context=multiprocessing_context, generator=generator,
@@ -300,7 +301,7 @@ def prepare_torch_dataloader(ds_or_db,
             else:
                 dl_bundle[name] = TorchDataLoader(dataset=ds,
                                                   batch_size=non_train_batch_size if non_train_batch_size else batch_size,
-                                                  shuffle=shuffle,
+                                                  shuffle=False if shuffle is None else shuffle,
                                                   sampler=non_train_sampler if non_train_sampler else sampler,
                                                   batch_sampler=batch_sampler,
                                                   num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
@@ -316,7 +317,7 @@ def prepare_torch_dataloader(ds_or_db,
         for name, ds in ds_or_db.items():
             if 'train' in name:
                 dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size,
-                                                  shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
+                                                  shuffle=True if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler,
                                                   num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                                   drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
                                                   multiprocessing_context=multiprocessing_context, generator=generator,
@@ -326,7 +327,7 @@ def prepare_torch_dataloader(ds_or_db,
             else:
                 dl_bundle[name] = TorchDataLoader(dataset=ds,
                                                   batch_size=non_train_batch_size if non_train_batch_size else batch_size,
-                                                  shuffle=shuffle,
+                                                  shuffle=False if shuffle is None else shuffle,
                                                   sampler=non_train_sampler if non_train_sampler else sampler,
                                                   batch_sampler=batch_sampler,
                                                   num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
@@ -340,7 +341,7 @@ def prepare_torch_dataloader(ds_or_db,

     elif isinstance(ds_or_db, HasLenGetitemType):
         dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
-                             shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
+                             shuffle=False if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler,
                              num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                              drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
                              multiprocessing_context=multiprocessing_context, generator=generator,
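Hypothetical usage of the new default (dataset contents invented for the example, import paths assumed from fastNLP 1.0): with shuffle left at None, only the 'train' loader shuffles, while an explicit value applies to every split.

from fastNLP import DataSet, prepare_torch_dataloader

dls = prepare_torch_dataloader(
    {'train': DataSet({'x': [1, 2, 3, 4]}), 'dev': DataSet({'x': [5, 6]})},
    batch_size=2,    # shuffle defaults to None: train shuffled, dev not
)
dls_fixed = prepare_torch_dataloader(
    {'train': DataSet({'x': [1, 2, 3, 4]}), 'dev': DataSet({'x': [5, 6]})},
    shuffle=False,   # explicit value: nothing is shuffled
)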


fastNLP/core/metrics/accuracy.py (+1 -1)

@@ -69,7 +69,7 @@ class Accuracy(Metric):
         elif pred.ndim == target.ndim + 1:
             pred = pred.argmax(axis=-1)
             if seq_len is None and target.ndim > 1:
-                logger.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
+                logger.warning("You are not passing `seq_len` to exclude pad when calculate accuracy.")

         else:
             raise RuntimeError(f"when pred have size:{pred.shape}, target should have size: {pred.shape} or "


fastNLP/core/metrics/classify_f1_pre_rec_metric.py (+1 -1)

@@ -156,7 +156,7 @@ class ClassifyFPreRecMetric(Metric):
         elif pred.ndim == target.ndim + 1:
             pred = pred.argmax(axis=-1)
             if seq_len is None and target.ndim > 1:
-                logger.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
+                logger.warning("You are not passing `seq_len` to exclude pad when calculate accuracy.")
         else:
             raise RuntimeError(f"when pred have "
                                f"size:{pred.shape}, target should have size: {pred.shape} or "


fastNLP/core/metrics/span_f1_pre_rec_metric.py (+1 -1)

@@ -39,7 +39,7 @@ def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encod
             f"encoding_type."
         tags = tags.replace(tag, '')  # delete this value
     if tags:  # if non-empty, some tags were never used
-        logger.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your "
+        logger.warning(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your "
                     "encoding_type.")


fastNLP/core/utils/utils.py (+1 -1)

@@ -554,7 +554,7 @@ def deprecated(help_message: Optional[str] = None):
         def wrapper(*args, **kwargs):
             func_hash = hash(deprecated_function)
             if func_hash not in _emitted_deprecation_warnings:
-                logger.warn(warning_msg, category=FutureWarning, stacklevel=2)
+                logger.warning(warning_msg, category=FutureWarning, stacklevel=2)
                 _emitted_deprecation_warnings.add(func_hash)
             return deprecated_function(*args, **kwargs)




fastNLP/embeddings/torch/static_embedding.py (+2 -2)

@@ -286,7 +286,7 @@ class StaticEmbedding(TokenEmbedding):
                     if word in vocab:
                         index = vocab.to_index(word)
                         if index in matrix:
-                            logger.warn(f"Word has more than one vector in embedding file. Set logger level to "
+                            logger.warning(f"Word has more than one vector in embedding file. Set logger level to "
                                         f"DEBUG for detail.")
                             logger.debug(f"Word:{word} occurs again in line:{idx}(starts from 0)")
                         matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim))
@@ -295,7 +295,7 @@ class StaticEmbedding(TokenEmbedding):
                         found_count += 1
                 except Exception as e:
                     if error == 'ignore':
-                        logger.warn("Error occurred at the {} line.".format(idx))
+                        logger.warning("Error occurred at the {} line.".format(idx))
                     else:
                         logger.error("Error occurred at the {} line.".format(idx))
                         raise e


fastNLP/io/embed_loader.py (+2 -2)

@@ -91,7 +91,7 @@ class EmbedLoader:
                     hit_flags[index] = True
             except Exception as e:
                 if error == 'ignore':
-                    logger.warn("Error occurred at the {} line.".format(idx))
+                    logger.warning("Error occurred at the {} line.".format(idx))
                 else:
                     logging.error("Error occurred at the {} line.".format(idx))
                     raise e
@@ -156,7 +156,7 @@ class EmbedLoader:
                         found_pad = True
             except Exception as e:
                 if error == 'ignore':
-                    logger.warn("Error occurred at the {} line.".format(idx))
+                    logger.warning("Error occurred at the {} line.".format(idx))
                     pass
                 else:
                     logging.error("Error occurred at the {} line.".format(idx))


fastNLP/io/loader/classification.py (+1 -1)

@@ -345,7 +345,7 @@ class SST2Loader(Loader):
         with open(path, 'r', encoding='utf-8') as f:
             f.readline()  # skip the header
             if 'test' in os.path.split(path)[1]:
-                logger.warn("SST2's test file has no target.")
+                logger.warning("SST2's test file has no target.")
             for line in f:
                 line = line.strip()
                 if line:


fastNLP/io/loader/matching.py (+3 -3)

@@ -55,7 +55,7 @@ class MNLILoader(Loader):
         with open(path, 'r', encoding='utf-8') as f:
             f.readline()  # skip the header
             if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'):
-                logger.warn("MNLI's test file has no target.")
+                logger.warning("MNLI's test file has no target.")
             for line in f:
                 line = line.strip()
                 if line:
@@ -227,7 +227,7 @@ class QNLILoader(JsonLoader):
         with open(path, 'r', encoding='utf-8') as f:
             f.readline()  # skip the header
             if path.endswith("test.tsv"):
-                logger.warn("QNLI's test file has no target.")
+                logger.warning("QNLI's test file has no target.")
             for line in f:
                 line = line.strip()
                 if line:
@@ -289,7 +289,7 @@ class RTELoader(Loader):
         with open(path, 'r', encoding='utf-8') as f:
             f.readline()  # skip the header
             if path.endswith("test.tsv"):
-                logger.warn("RTE's test file has no target.")
+                logger.warning("RTE's test file has no target.")
             for line in f:
                 line = line.strip()
                 if line:


fastNLP/io/pipe/matching.py (+2 -2)

@@ -146,7 +146,7 @@ class MatchingBertPipe(Pipe):
             warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \
                        f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \
                        f"data set but not in train data set!."
-            logger.warn(warn_msg)
+            logger.warning(warn_msg)
             print(warn_msg)
         has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if
@@ -291,7 +291,7 @@ class MatchingPipe(Pipe):
             warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \
                        f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \
                        f"data set but not in train data set!."
-            logger.warn(warn_msg)
+            logger.warning(warn_msg)
             print(warn_msg)
         has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if


fastNLP/io/pipe/utils.py (+1 -1)

@@ -138,7 +138,7 @@ def _indexize(data_bundle, input_field_names='words', target_field_names='target
                        f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \
                        f"data set but not in train data set!.\n" \
                        f"These label(s) are {tgt_vocab._no_create_word}"
-            logger.warn(warn_msg)
+            logger.warning(warn_msg)
             # log.warning(warn_msg)
         tgt_vocab.index_dataset(*[ds for ds in data_bundle.datasets.values() if ds.has_field(target_field_name)], field_name=target_field_name)
         data_bundle.set_vocab(tgt_vocab, target_field_name)


fastNLP/modules/mix_modules/utils.py (+1 -1)

@@ -112,7 +112,7 @@ def _jittor2torch(jittor_var: 'jittor.Var', device: Optional[Union[str, int]] =
     # if outputs has a _grad key, gradients can be computed
     no_gradient = not jittor_var.requires_grad if no_gradient is None else no_gradient
     if no_gradient == False:
-        logger.warn("The result tensor will not keep gradients due to differences between jittor and pytorch.")
+        logger.warning("The result tensor will not keep gradients due to differences between jittor and pytorch.")
     jittor_numpy = jittor_var.numpy()
     if not np.issubdtype(jittor_numpy.dtype, np.inexact):
         no_gradient = True


fastNLP/transformers/torch/configuration_utils.py (+1 -1)

@@ -327,7 +327,7 @@ class PretrainedConfig:

         # Deal with gradient checkpointing
         if kwargs.get("gradient_checkpointing", False):
-            logger.warn(
+            logger.warning(
                 "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
                 "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the "
                 "`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`."


fastNLP/transformers/torch/generation_beam_search.py (+1 -1)

@@ -195,7 +195,7 @@ class BeamSearchScorer(BeamScorer):
             )

         if "max_length" in kwargs:
-            logger.warn(
+            logger.warning(
                 "Passing `max_length` to BeamSearchScorer is deprecated and has no effect."
                 "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
                 ",or `group_beam_search(...)`."


fastNLP/transformers/torch/generation_utils.py (+7 -7)

@@ -872,7 +872,7 @@ class GenerationMixin:
             max_length = self.config.max_length
         elif max_length is not None and max_new_tokens is not None:
             # Both are set, this is odd, raise a warning
-            logger.warn(
+            logger.warning(
                 "Both `max_length` and `max_new_tokens` have been set but they serve the same purpose.", UserWarning
             )

@@ -1239,7 +1239,7 @@ class GenerationMixin:
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         if max_length is not None:
-            logger.warn(
+            logger.warning(
                 "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                 UserWarning,
             )
@@ -1475,7 +1475,7 @@ class GenerationMixin:
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         if max_length is not None:
-            logger.warn(
+            logger.warning(
                 "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                 UserWarning,
             )
@@ -1726,13 +1726,13 @@ class GenerationMixin:
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         if max_length is not None:
-            logger.warn(
+            logger.warning(
                 "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                 UserWarning,
             )
             stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         if len(stopping_criteria) == 0:
-            logger.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
+            logger.warning("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
         output_scores = output_scores if output_scores is not None else self.config.output_scores
@@ -2030,7 +2030,7 @@ class GenerationMixin:
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         if max_length is not None:
-            logger.warn(
+            logger.warning(
                 "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                 UserWarning,
             )
@@ -2325,7 +2325,7 @@ class GenerationMixin:
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         if max_length is not None:
-            logger.warn(
+            logger.warning(
                 "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                 UserWarning,
             )


fastNLP/transformers/torch/models/auto/auto_factory.py (+1 -1)

@@ -401,7 +401,7 @@ class _BaseAutoModelClass:
                 "the option `trust_remote_code=True` to remove this error."
             )
         if kwargs.get("revision", None) is None:
-            logger.warn(
+            logger.warning(
                 "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
                 "no malicious code has been contributed in a newer revision."
             )


fastNLP/transformers/torch/models/auto/configuration_auto.py (+1 -1)

@@ -130,7 +130,7 @@ class _LazyLoadAllMappings(OrderedDict):
     def _initialize(self):
         if self._initialized:
             return
-        # logger.warn(
+        # logger.warning(
         #     "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. "
         #     "It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.",
         #     FutureWarning,


fastNLP/transformers/torch/models/auto/modeling_auto.py (+2 -2)

@@ -306,7 +306,7 @@ AutoModelForSpeechSeq2Seq = auto_class_update(
 class AutoModelWithLMHead(_AutoModelWithLMHead):
     @classmethod
     def from_config(cls, config):
-        logger.warn(
+        logger.warning(
             "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
             "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
             "`AutoModelForSeq2SeqLM` for encoder-decoder models.",
@@ -316,7 +316,7 @@ class AutoModelWithLMHead(_AutoModelWithLMHead):

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        logger.warn(
+        logger.warning(
             "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
             "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
             "`AutoModelForSeq2SeqLM` for encoder-decoder models.",


fastNLP/transformers/torch/models/bart/modeling_bart.py (+1 -1)

@@ -513,7 +513,7 @@ class BartPretrainedModel(PreTrainedModel):

 class PretrainedBartModel(BartPretrainedModel):
     def __init_subclass__(self):
-        logger.warn(
+        logger.warning(
             "The class `PretrainedBartModel` has been depreciated, please use `BartPretrainedModel` instead.",
             FutureWarning,
         )


fastNLP/transformers/torch/models/bert/modeling_bert.py (+1 -1)

@@ -1374,7 +1374,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
         """

         if "next_sentence_label" in kwargs:
-            logger.warn(
+            logger.warning(
                 "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
                 FutureWarning,
             )


fastNLP/transformers/torch/models/cpt/modeling_cpt.py (+1 -1)

@@ -724,7 +724,7 @@ class CPTDecoder(CPTPretrainedModel):
         if getattr(self.config, "gradient_checkpointing", False) and self.training:

             if use_cache:
-                logger.warn(
+                logger.warning(
                     "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
                     "`use_cache=False`..."
                 )


fastNLP/transformers/torch/tokenization_utils_base.py (+6 -6)

@@ -312,7 +312,7 @@ class BatchEncoding(UserDict):
         """
         if not self._encodings:
             raise ValueError("words() is not available when using Python-based tokenizers")
-        logger.warn(
+        logger.warning(
             "`BatchEncoding.words()` property is deprecated and should be replaced with the identical, "
             "but more self-explanatory `BatchEncoding.word_ids()` property.",
             FutureWarning,
@@ -1601,7 +1601,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
                     f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not "
                     "supported for this tokenizer. Use a model identifier or the path to a directory instead."
                 )
-            logger.warn(
+            logger.warning(
                 f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is deprecated and "
                 "won't be possible anymore in v5. Use a model identifier or the path to a directory instead.",
                 FutureWarning,
@@ -2163,7 +2163,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         # Get padding strategy
         if padding is False and old_pad_to_max_length:
             if verbose:
-                logger.warn(
+                logger.warning(
                     "The `pad_to_max_length` argument is deprecated and will be removed in a future version, "
                     "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or "
                     "use `padding='max_length'` to pad to a max length. In this case, you can give a specific "
@@ -2184,7 +2184,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
                         "To pad to max length, use `padding='max_length'`."
                     )
                 if old_pad_to_max_length is not False:
-                    logger.warn("Though `pad_to_max_length` = `True`, it is ignored because `padding`=`True`.")
+                    logger.warning("Though `pad_to_max_length` = `True`, it is ignored because `padding`=`True`.")
             padding_strategy = PaddingStrategy.LONGEST  # Default to pad to the longest sequence in the batch
         elif not isinstance(padding, PaddingStrategy):
             padding_strategy = PaddingStrategy(padding)
@@ -2196,7 +2196,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         # Get truncation strategy
         if truncation is False and old_truncation_strategy != "do_not_truncate":
             if verbose:
-                logger.warn(
+                logger.warning(
                     "The `truncation_strategy` argument is deprecated and will be removed in a future version, "
                     "use `truncation=True` to truncate examples to a max length. You can give a specific "
                     "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
@@ -3352,7 +3352,7 @@ model_inputs["labels"] = labels["input_ids"]
         See the documentation of your specific tokenizer for more details on the specific arguments to the tokenizer of choice.
         For a more complete example, see the implementation of `prepare_seq2seq_batch`.
         """
-        logger.warn(formatted_warning, FutureWarning)
+        logger.warning(formatted_warning, FutureWarning)
         # mBART-specific kwargs that should be ignored by other models.
         kwargs.pop("src_lang", None)
         kwargs.pop("tgt_lang", None)

