@@ -44,7 +44,7 @@ class FitlogCallback(HasMonitorCallback): | |||||
if get_global_rank() != 0: # 如果不是 global rank 为 0 ,需要关闭 fitlog | if get_global_rank() != 0: # 如果不是 global rank 为 0 ,需要关闭 fitlog | ||||
fitlog.debug() | fitlog.debug() | ||||
super().on_after_trainer_initialized(trainer, driver) | super().on_after_trainer_initialized(trainer, driver) | ||||
fitlog.add_other('launch_time', os.environ['FASTNLP_LAUNCH_TIME']) | |||||
fitlog.add_other(name='launch_time', value=os.environ['FASTNLP_LAUNCH_TIME']) | |||||
def on_sanity_check_end(self, trainer, sanity_check_res): | def on_sanity_check_end(self, trainer, sanity_check_res): | ||||
super(FitlogCallback, self).on_sanity_check_end(trainer, sanity_check_res) | super(FitlogCallback, self).on_sanity_check_end(trainer, sanity_check_res) | ||||
@@ -105,14 +105,16 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
def on_train_end(self, trainer): | def on_train_end(self, trainer): | ||||
if abs(self.monitor_value) != float('inf'): # 如果是 inf 说明从来没有运行过。 | if abs(self.monitor_value) != float('inf'): # 如果是 inf 说明从来没有运行过。 | ||||
if self.real_save_folder: | |||||
logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value}...") | |||||
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | |||||
model_load_fn=self.model_load_fn) | |||||
else: | |||||
logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value}...") | |||||
self.buffer.seek(0) | |||||
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) | |||||
# 如果是分布式且报错了,就不要加载了,防止barrier的问题 | |||||
if not (trainer.driver.is_distributed() and self.encounter_exception): | |||||
if self.real_save_folder: | |||||
logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value}...") | |||||
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | |||||
model_load_fn=self.model_load_fn) | |||||
else: | |||||
logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value}...") | |||||
self.buffer.seek(0) | |||||
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) | |||||
if self.delete_after_after: | if self.delete_after_after: | ||||
if not self.encounter_exception: # 防止出现死锁。 | if not self.encounter_exception: # 防止出现死锁。 | ||||
trainer.driver.barrier() | trainer.driver.barrier() | ||||
@@ -22,9 +22,10 @@ class ProgressCallback(HasMonitorCallback): | |||||
self.best_monitor_step = -1 | self.best_monitor_step = -1 | ||||
self.best_results = None | self.best_results = None | ||||
def record_better_monitor(self, trainer): | |||||
def record_better_monitor(self, trainer, results): | |||||
self.best_monitor_step = trainer.global_forward_batches | self.best_monitor_step = trainer.global_forward_batches | ||||
self.best_monitor_epoch = trainer.cur_epoch_idx | self.best_monitor_epoch = trainer.cur_epoch_idx | ||||
self.best_results = self.itemize_results(results) | |||||
def on_train_end(self, trainer): | def on_train_end(self, trainer): | ||||
if self.best_monitor_epoch != -1: | if self.best_monitor_epoch != -1: | ||||
@@ -138,7 +139,7 @@ class RichCallback(ProgressCallback): | |||||
characters = '-' | characters = '-' | ||||
if self.monitor is not None: | if self.monitor is not None: | ||||
if self.is_better_results(results, keep_if_better=True): | if self.is_better_results(results, keep_if_better=True): | ||||
self.record_better_monitor(trainer) | |||||
self.record_better_monitor(trainer, results) | |||||
if abs(self.monitor_value) != float('inf'): | if abs(self.monitor_value) != float('inf'): | ||||
rule_style = 'spring_green3' | rule_style = 'spring_green3' | ||||
text_style = '[bold]' | text_style = '[bold]' | ||||
@@ -154,7 +155,6 @@ class RichCallback(ProgressCallback): | |||||
self.progress_bar.console.print_json(results) | self.progress_bar.console.print_json(results) | ||||
else: | else: | ||||
self.progress_bar.print(results) | self.progress_bar.print(results) | ||||
self.best_results = results | |||||
def clear_tasks(self): | def clear_tasks(self): | ||||
for key, taskid in self.task2id.items(): | for key, taskid in self.task2id.items(): | ||||
@@ -222,7 +222,7 @@ class RawTextCallback(ProgressCallback): | |||||
text = '' | text = '' | ||||
if self.monitor is not None: | if self.monitor is not None: | ||||
if self.is_better_results(results, keep_if_better=True): | if self.is_better_results(results, keep_if_better=True): | ||||
self.record_better_monitor(trainer) | |||||
self.record_better_monitor(trainer, results) | |||||
if abs(self.monitor_value) != float('inf'): | if abs(self.monitor_value) != float('inf'): | ||||
text = '+'*self.num_signs + base_text + '+'*self.num_signs | text = '+'*self.num_signs + base_text + '+'*self.num_signs | ||||
if len(text) == 0: | if len(text) == 0: | ||||
@@ -234,7 +234,6 @@ class RawTextCallback(ProgressCallback): | |||||
if self.format_json: | if self.format_json: | ||||
results = json.dumps(results) | results = json.dumps(results) | ||||
logger.info(results) | logger.info(results) | ||||
self.best_results = results | |||||
@property | @property | ||||
def name(self): # progress bar的名称 | def name(self): # progress bar的名称 | ||||
@@ -311,7 +310,7 @@ class TqdmCallback(ProgressCallback): | |||||
text = '' | text = '' | ||||
if self.monitor is not None: | if self.monitor is not None: | ||||
if self.is_better_results(results, keep_if_better=True): | if self.is_better_results(results, keep_if_better=True): | ||||
self.record_better_monitor(trainer) | |||||
self.record_better_monitor(trainer, results) | |||||
if abs(self.monitor_value) != float('inf'): | if abs(self.monitor_value) != float('inf'): | ||||
text = '+'*self.num_signs + base_text + '+'*self.num_signs | text = '+'*self.num_signs + base_text + '+'*self.num_signs | ||||
if len(text) == 0: | if len(text) == 0: | ||||
@@ -323,7 +322,6 @@ class TqdmCallback(ProgressCallback): | |||||
if self.format_json: | if self.format_json: | ||||
results = json.dumps(results) | results = json.dumps(results) | ||||
logger.info(results) | logger.info(results) | ||||
self.best_results = results | |||||
def clear_tasks(self): | def clear_tasks(self): | ||||
for key, taskid in self.task2id.items(): | for key, taskid in self.task2id.items(): | ||||
@@ -200,7 +200,7 @@ class JittorDataLoader: | |||||
return self.cur_batch_indices | return self.cur_batch_indices | ||||
def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = False, | |||||
def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = None, | |||||
drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, | drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, | ||||
stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, | stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, | ||||
collate_fn: Union[None, str, Callable] = "auto", | collate_fn: Union[None, str, Callable] = "auto", | ||||
@@ -230,7 +230,8 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||||
:param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 :class:`Dict` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过该参数 | :param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 :class:`Dict` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过该参数 | ||||
设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。 默认为 ``16``。 | 设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。 默认为 ``16``。 | ||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | ||||
:param shuffle: 是否打乱数据集, 默认为 ``False``。 | |||||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 'train' 则设置其 shuffle 为 True , | |||||
其它的为 False 。 | |||||
:param drop_last: 当 ``drop_last=True`` 时,:class:`JittorDataLoader` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | :param drop_last: 当 ``drop_last=True`` 时,:class:`JittorDataLoader` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | ||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | ||||
:param num_workers: 当 ``num_workers > 0`` 时, :class:`JittorDataLoader` 会开启 num_workers 个子进程来处理数据, 可以加快 | :param num_workers: 当 ``num_workers > 0`` 时, :class:`JittorDataLoader` 会开启 num_workers 个子进程来处理数据, 可以加快 | ||||
@@ -258,7 +259,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||||
dl_bundle = {} | dl_bundle = {} | ||||
for name, ds in ds_or_db.iter_datasets(): | for name, ds in ds_or_db.iter_datasets(): | ||||
if 'train' in name: | if 'train' in name: | ||||
dl_bundle[name] = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle, | |||||
dl_bundle[name] = JittorDataLoader(ds, batch_size=batch_size, shuffle=True if shuffle is None else shuffle, | |||||
drop_last=drop_last, num_workers=num_workers, | drop_last=drop_last, num_workers=num_workers, | ||||
buffer_size=buffer_size, | buffer_size=buffer_size, | ||||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, | stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, | ||||
@@ -267,7 +268,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||||
else: | else: | ||||
dl_bundle[name] = JittorDataLoader(ds, | dl_bundle[name] = JittorDataLoader(ds, | ||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | batch_size=non_train_batch_size if non_train_batch_size else batch_size, | ||||
shuffle=shuffle, | |||||
shuffle=False if shuffle is None else shuffle, | |||||
drop_last=drop_last, num_workers=num_workers, | drop_last=drop_last, num_workers=num_workers, | ||||
buffer_size=buffer_size, | buffer_size=buffer_size, | ||||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, | stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, | ||||
@@ -279,14 +280,14 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||||
ds_dict = {} | ds_dict = {} | ||||
for name, ds in ds_or_db.items(): | for name, ds in ds_or_db.items(): | ||||
if 'train' in name: | if 'train' in name: | ||||
dl = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle, | |||||
dl = JittorDataLoader(ds, batch_size=batch_size, shuffle=True if shuffle is None else shuffle, | |||||
drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, | drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, | ||||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | ||||
collate_fn=collate_fn) | collate_fn=collate_fn) | ||||
else: | else: | ||||
dl = JittorDataLoader(ds, | dl = JittorDataLoader(ds, | ||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | batch_size=non_train_batch_size if non_train_batch_size else batch_size, | ||||
shuffle=shuffle, | |||||
shuffle=False if shuffle is None else shuffle, | |||||
drop_last=drop_last, num_workers=num_workers, | drop_last=drop_last, num_workers=num_workers, | ||||
buffer_size=buffer_size, | buffer_size=buffer_size, | ||||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, | stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, | ||||
@@ -296,7 +297,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||||
return ds_dict | return ds_dict | ||||
elif isinstance(ds_or_db, HasLenGetitemType): | elif isinstance(ds_or_db, HasLenGetitemType): | ||||
dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle, | |||||
dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=False if shuffle is None else shuffle, | |||||
drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, | drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, | ||||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | ||||
collate_fn=collate_fn) | collate_fn=collate_fn) | ||||
@@ -293,7 +293,8 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | :param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | ||||
dataset 的下标 index ;默认为 None,当其不为 None 时,batch_size, shuffle 参数均失效。 | dataset 的下标 index ;默认为 None,当其不为 None 时,batch_size, shuffle 参数均失效。 | ||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | ||||
:param shuffle: 是否将数据打乱,若``shuffle=True``则会将dataset打乱;若否则什么也不做。 | |||||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 'train' 则设置其 shuffle 为 True , | |||||
其它的为 False 。 | |||||
:param drop_last: 当 ``drop_last=True`` 时,``PaddleDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | :param drop_last: 当 ``drop_last=True`` 时,``PaddleDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | ||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | ||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | ||||
@@ -326,7 +327,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places, | dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places, | ||||
return_list=return_list, | return_list=return_list, | ||||
batch_sampler=batch_sampler, batch_size=batch_size, | batch_sampler=batch_sampler, batch_size=batch_size, | ||||
shuffle=shuffle, | |||||
shuffle=True if shuffle is None else shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | ||||
use_shared_memory=use_shared_memory, | use_shared_memory=use_shared_memory, | ||||
use_buffer_reader=use_buffer_reader, | use_buffer_reader=use_buffer_reader, | ||||
@@ -337,7 +338,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
return_list=return_list, | return_list=return_list, | ||||
batch_sampler=batch_sampler, | batch_sampler=batch_sampler, | ||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | batch_size=non_train_batch_size if non_train_batch_size else batch_size, | ||||
shuffle=shuffle, | |||||
shuffle=False if shuffle is None else shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | ||||
use_shared_memory=use_shared_memory, | use_shared_memory=use_shared_memory, | ||||
use_buffer_reader=use_buffer_reader, | use_buffer_reader=use_buffer_reader, | ||||
@@ -350,7 +351,8 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
for name, ds in ds_or_db.items(): | for name, ds in ds_or_db.items(): | ||||
if 'train' in name: | if 'train' in name: | ||||
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | ||||
batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, | |||||
batch_sampler=batch_sampler, batch_size=batch_size, | |||||
shuffle=True if shuffle is None else shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | ||||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | ||||
timeout=timeout, worker_init_fn=worker_init_fn, | timeout=timeout, worker_init_fn=worker_init_fn, | ||||
@@ -359,7 +361,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | ||||
batch_sampler=batch_sampler, | batch_sampler=batch_sampler, | ||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | batch_size=non_train_batch_size if non_train_batch_size else batch_size, | ||||
shuffle=shuffle, | |||||
shuffle=False if shuffle is None else shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | ||||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | ||||
timeout=timeout, worker_init_fn=worker_init_fn, | timeout=timeout, worker_init_fn=worker_init_fn, | ||||
@@ -369,7 +371,8 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
elif isinstance(ds_or_db, HasLenGetitemType): | elif isinstance(ds_or_db, HasLenGetitemType): | ||||
dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list, | dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list, | ||||
batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, | |||||
batch_sampler=batch_sampler, batch_size=batch_size, | |||||
shuffle=False if shuffle is None else shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | ||||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | ||||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | ||||
@@ -13,7 +13,7 @@ from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS | |||||
from ..log import logger | from ..log import logger | ||||
def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False, | |||||
def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = None, drop_last: bool = False, | |||||
collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0, | collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0, | ||||
backend: str = 'auto'): | backend: str = 'auto'): | ||||
""" | """ | ||||
@@ -28,7 +28,8 @@ def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, dro | |||||
* 为字典型 或 :class:`~fastNLP.io.DataBundle` 数据时,返回 `Dict` 类型的数据。 | * 为字典型 或 :class:`~fastNLP.io.DataBundle` 数据时,返回 `Dict` 类型的数据。 | ||||
:param batch_size: 批次大小。 | :param batch_size: 批次大小。 | ||||
:param shuffle: 是否打乱数据集。 | |||||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``dataset`` 可以判断出哪个是 'train' 则设置其 shuffle 为 True , | |||||
其它的为 False 。 | |||||
:param drop_last: 当最后一个 batch 不足 batch_size 数量时,是否丢弃。 | :param drop_last: 当最后一个 batch 不足 batch_size 数量时,是否丢弃。 | ||||
:param collate_fn: 用于处理一个 batch 的函数,一般包括 padding 和转为 tensor。有以下三种取值: | :param collate_fn: 用于处理一个 batch 的函数,一般包括 padding 和转为 tensor。有以下三种取值: | ||||
@@ -218,7 +218,7 @@ class TorchDataLoader(DataLoader): | |||||
def prepare_torch_dataloader(ds_or_db, | def prepare_torch_dataloader(ds_or_db, | ||||
batch_size: int = 16, | batch_size: int = 16, | ||||
shuffle: bool = False, | |||||
shuffle: bool = None, | |||||
sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | ||||
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | ||||
num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', | num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', | ||||
@@ -252,7 +252,8 @@ def prepare_torch_dataloader(ds_or_db, | |||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | ||||
:param non_train_batch_size: 非 'train' 数据集的 ``TorchDataLoader`` 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | :param non_train_batch_size: 非 'train' 数据集的 ``TorchDataLoader`` 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | ||||
:param shuffle: 是否打乱数据集, 默认为 ``False``。 | |||||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 'train' 则设置其 shuffle 为 True , | |||||
其它的为 False 。 | |||||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | :param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | ||||
默认为None, 当其不为 None 时, shuffle 参数无效。 | 默认为None, 当其不为 None 时, shuffle 参数无效。 | ||||
:param non_train_sampler: 非 'train' 数据集的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | :param non_train_sampler: 非 'train' 数据集的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | ||||
@@ -290,7 +291,7 @@ def prepare_torch_dataloader(ds_or_db, | |||||
for name, ds in ds_or_db.iter_datasets(): | for name, ds in ds_or_db.iter_datasets(): | ||||
if 'train' in name: | if 'train' in name: | ||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, | dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, | ||||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
shuffle=True if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | ||||
multiprocessing_context=multiprocessing_context, generator=generator, | multiprocessing_context=multiprocessing_context, generator=generator, | ||||
@@ -300,7 +301,7 @@ def prepare_torch_dataloader(ds_or_db, | |||||
else: | else: | ||||
dl_bundle[name] = TorchDataLoader(dataset=ds, | dl_bundle[name] = TorchDataLoader(dataset=ds, | ||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | batch_size=non_train_batch_size if non_train_batch_size else batch_size, | ||||
shuffle=shuffle, | |||||
shuffle=False if shuffle is None else shuffle, | |||||
sampler=non_train_sampler if non_train_sampler else sampler, | sampler=non_train_sampler if non_train_sampler else sampler, | ||||
batch_sampler=batch_sampler, | batch_sampler=batch_sampler, | ||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
@@ -316,7 +317,7 @@ def prepare_torch_dataloader(ds_or_db, | |||||
for name, ds in ds_or_db.items(): | for name, ds in ds_or_db.items(): | ||||
if 'train' in name: | if 'train' in name: | ||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, | dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, | ||||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
shuffle=True if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | ||||
multiprocessing_context=multiprocessing_context, generator=generator, | multiprocessing_context=multiprocessing_context, generator=generator, | ||||
@@ -326,7 +327,7 @@ def prepare_torch_dataloader(ds_or_db, | |||||
else: | else: | ||||
dl_bundle[name] = TorchDataLoader(dataset=ds, | dl_bundle[name] = TorchDataLoader(dataset=ds, | ||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | batch_size=non_train_batch_size if non_train_batch_size else batch_size, | ||||
shuffle=shuffle, | |||||
shuffle=False if shuffle is None else shuffle, | |||||
sampler=non_train_sampler if non_train_sampler else sampler, | sampler=non_train_sampler if non_train_sampler else sampler, | ||||
batch_sampler=batch_sampler, | batch_sampler=batch_sampler, | ||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
@@ -340,7 +341,7 @@ def prepare_torch_dataloader(ds_or_db, | |||||
elif isinstance(ds_or_db, HasLenGetitemType): | elif isinstance(ds_or_db, HasLenGetitemType): | ||||
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, | dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, | ||||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
shuffle=False if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | ||||
multiprocessing_context=multiprocessing_context, generator=generator, | multiprocessing_context=multiprocessing_context, generator=generator, | ||||
@@ -69,7 +69,7 @@ class Accuracy(Metric): | |||||
elif pred.ndim == target.ndim + 1: | elif pred.ndim == target.ndim + 1: | ||||
pred = pred.argmax(axis=-1) | pred = pred.argmax(axis=-1) | ||||
if seq_len is None and target.ndim > 1: | if seq_len is None and target.ndim > 1: | ||||
logger.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | |||||
logger.warning("You are not passing `seq_len` to exclude pad when calculate accuracy.") | |||||
else: | else: | ||||
raise RuntimeError(f"when pred have size:{pred.shape}, target should have size: {pred.shape} or " | raise RuntimeError(f"when pred have size:{pred.shape}, target should have size: {pred.shape} or " | ||||
@@ -156,7 +156,7 @@ class ClassifyFPreRecMetric(Metric): | |||||
elif pred.ndim == target.ndim + 1: | elif pred.ndim == target.ndim + 1: | ||||
pred = pred.argmax(axis=-1) | pred = pred.argmax(axis=-1) | ||||
if seq_len is None and target.ndim > 1: | if seq_len is None and target.ndim > 1: | ||||
logger.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | |||||
logger.warning("You are not passing `seq_len` to exclude pad when calculate accuracy.") | |||||
else: | else: | ||||
raise RuntimeError(f"when pred have " | raise RuntimeError(f"when pred have " | ||||
f"size:{pred.shape}, target should have size: {pred.shape} or " | f"size:{pred.shape}, target should have size: {pred.shape} or " | ||||
@@ -39,7 +39,7 @@ def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encod | |||||
f"encoding_type." | f"encoding_type." | ||||
tags = tags.replace(tag, '') # 删除该值 | tags = tags.replace(tag, '') # 删除该值 | ||||
if tags: # 如果不为空,说明出现了未使用的tag | if tags: # 如果不为空,说明出现了未使用的tag | ||||
logger.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your " | |||||
logger.warning(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your " | |||||
"encoding_type.") | "encoding_type.") | ||||
@@ -554,7 +554,7 @@ def deprecated(help_message: Optional[str] = None): | |||||
def wrapper(*args, **kwargs): | def wrapper(*args, **kwargs): | ||||
func_hash = hash(deprecated_function) | func_hash = hash(deprecated_function) | ||||
if func_hash not in _emitted_deprecation_warnings: | if func_hash not in _emitted_deprecation_warnings: | ||||
logger.warn(warning_msg, category=FutureWarning, stacklevel=2) | |||||
logger.warning(warning_msg, category=FutureWarning, stacklevel=2) | |||||
_emitted_deprecation_warnings.add(func_hash) | _emitted_deprecation_warnings.add(func_hash) | ||||
return deprecated_function(*args, **kwargs) | return deprecated_function(*args, **kwargs) | ||||
@@ -286,7 +286,7 @@ class StaticEmbedding(TokenEmbedding): | |||||
if word in vocab: | if word in vocab: | ||||
index = vocab.to_index(word) | index = vocab.to_index(word) | ||||
if index in matrix: | if index in matrix: | ||||
logger.warn(f"Word has more than one vector in embedding file. Set logger level to " | |||||
logger.warning(f"Word has more than one vector in embedding file. Set logger level to " | |||||
f"DEBUG for detail.") | f"DEBUG for detail.") | ||||
logger.debug(f"Word:{word} occurs again in line:{idx}(starts from 0)") | logger.debug(f"Word:{word} occurs again in line:{idx}(starts from 0)") | ||||
matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim)) | matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim)) | ||||
@@ -295,7 +295,7 @@ class StaticEmbedding(TokenEmbedding): | |||||
found_count += 1 | found_count += 1 | ||||
except Exception as e: | except Exception as e: | ||||
if error == 'ignore': | if error == 'ignore': | ||||
logger.warn("Error occurred at the {} line.".format(idx)) | |||||
logger.warning("Error occurred at the {} line.".format(idx)) | |||||
else: | else: | ||||
logger.error("Error occurred at the {} line.".format(idx)) | logger.error("Error occurred at the {} line.".format(idx)) | ||||
raise e | raise e | ||||
@@ -91,7 +91,7 @@ class EmbedLoader: | |||||
hit_flags[index] = True | hit_flags[index] = True | ||||
except Exception as e: | except Exception as e: | ||||
if error == 'ignore': | if error == 'ignore': | ||||
logger.warn("Error occurred at the {} line.".format(idx)) | |||||
logger.warning("Error occurred at the {} line.".format(idx)) | |||||
else: | else: | ||||
logging.error("Error occurred at the {} line.".format(idx)) | logging.error("Error occurred at the {} line.".format(idx)) | ||||
raise e | raise e | ||||
@@ -156,7 +156,7 @@ class EmbedLoader: | |||||
found_pad = True | found_pad = True | ||||
except Exception as e: | except Exception as e: | ||||
if error == 'ignore': | if error == 'ignore': | ||||
logger.warn("Error occurred at the {} line.".format(idx)) | |||||
logger.warning("Error occurred at the {} line.".format(idx)) | |||||
pass | pass | ||||
else: | else: | ||||
logging.error("Error occurred at the {} line.".format(idx)) | logging.error("Error occurred at the {} line.".format(idx)) | ||||
@@ -345,7 +345,7 @@ class SST2Loader(Loader): | |||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
f.readline() # 跳过header | f.readline() # 跳过header | ||||
if 'test' in os.path.split(path)[1]: | if 'test' in os.path.split(path)[1]: | ||||
logger.warn("SST2's test file has no target.") | |||||
logger.warning("SST2's test file has no target.") | |||||
for line in f: | for line in f: | ||||
line = line.strip() | line = line.strip() | ||||
if line: | if line: | ||||
@@ -55,7 +55,7 @@ class MNLILoader(Loader): | |||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
f.readline() # 跳过header | f.readline() # 跳过header | ||||
if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'): | if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'): | ||||
logger.warn("MNLI's test file has no target.") | |||||
logger.warning("MNLI's test file has no target.") | |||||
for line in f: | for line in f: | ||||
line = line.strip() | line = line.strip() | ||||
if line: | if line: | ||||
@@ -227,7 +227,7 @@ class QNLILoader(JsonLoader): | |||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
f.readline() # 跳过header | f.readline() # 跳过header | ||||
if path.endswith("test.tsv"): | if path.endswith("test.tsv"): | ||||
logger.warn("QNLI's test file has no target.") | |||||
logger.warning("QNLI's test file has no target.") | |||||
for line in f: | for line in f: | ||||
line = line.strip() | line = line.strip() | ||||
if line: | if line: | ||||
@@ -289,7 +289,7 @@ class RTELoader(Loader): | |||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
f.readline() # 跳过header | f.readline() # 跳过header | ||||
if path.endswith("test.tsv"): | if path.endswith("test.tsv"): | ||||
logger.warn("RTE's test file has no target.") | |||||
logger.warning("RTE's test file has no target.") | |||||
for line in f: | for line in f: | ||||
line = line.strip() | line = line.strip() | ||||
if line: | if line: | ||||
@@ -146,7 +146,7 @@ class MatchingBertPipe(Pipe): | |||||
warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \ | warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \ | ||||
f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \ | f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \ | ||||
f"data set but not in train data set!." | f"data set but not in train data set!." | ||||
logger.warn(warn_msg) | |||||
logger.warning(warn_msg) | |||||
print(warn_msg) | print(warn_msg) | ||||
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if | has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if | ||||
@@ -291,7 +291,7 @@ class MatchingPipe(Pipe): | |||||
warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \ | warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \ | ||||
f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \ | f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \ | ||||
f"data set but not in train data set!." | f"data set but not in train data set!." | ||||
logger.warn(warn_msg) | |||||
logger.warning(warn_msg) | |||||
print(warn_msg) | print(warn_msg) | ||||
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if | has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if | ||||
@@ -138,7 +138,7 @@ def _indexize(data_bundle, input_field_names='words', target_field_names='target | |||||
f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \ | f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \ | ||||
f"data set but not in train data set!.\n" \ | f"data set but not in train data set!.\n" \ | ||||
f"These label(s) are {tgt_vocab._no_create_word}" | f"These label(s) are {tgt_vocab._no_create_word}" | ||||
logger.warn(warn_msg) | |||||
logger.warning(warn_msg) | |||||
# log.warning(warn_msg) | # log.warning(warn_msg) | ||||
tgt_vocab.index_dataset(*[ds for ds in data_bundle.datasets.values() if ds.has_field(target_field_name)], field_name=target_field_name) | tgt_vocab.index_dataset(*[ds for ds in data_bundle.datasets.values() if ds.has_field(target_field_name)], field_name=target_field_name) | ||||
data_bundle.set_vocab(tgt_vocab, target_field_name) | data_bundle.set_vocab(tgt_vocab, target_field_name) | ||||
@@ -112,7 +112,7 @@ def _jittor2torch(jittor_var: 'jittor.Var', device: Optional[Union[str, int]] = | |||||
# 如果outputs有_grad键,可以实现求导 | # 如果outputs有_grad键,可以实现求导 | ||||
no_gradient = not jittor_var.requires_grad if no_gradient is None else no_gradient | no_gradient = not jittor_var.requires_grad if no_gradient is None else no_gradient | ||||
if no_gradient == False: | if no_gradient == False: | ||||
logger.warn("The result tensor will not keep gradients due to differences between jittor and pytorch.") | |||||
logger.warning("The result tensor will not keep gradients due to differences between jittor and pytorch.") | |||||
jittor_numpy = jittor_var.numpy() | jittor_numpy = jittor_var.numpy() | ||||
if not np.issubdtype(jittor_numpy.dtype, np.inexact): | if not np.issubdtype(jittor_numpy.dtype, np.inexact): | ||||
no_gradient = True | no_gradient = True | ||||
@@ -327,7 +327,7 @@ class PretrainedConfig: | |||||
# Deal with gradient checkpointing | # Deal with gradient checkpointing | ||||
if kwargs.get("gradient_checkpointing", False): | if kwargs.get("gradient_checkpointing", False): | ||||
logger.warn( | |||||
logger.warning( | |||||
"Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 " | "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 " | ||||
"Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the " | "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the " | ||||
"`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`." | "`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`." | ||||
@@ -195,7 +195,7 @@ class BeamSearchScorer(BeamScorer): | |||||
) | ) | ||||
if "max_length" in kwargs: | if "max_length" in kwargs: | ||||
logger.warn( | |||||
logger.warning( | |||||
"Passing `max_length` to BeamSearchScorer is deprecated and has no effect." | "Passing `max_length` to BeamSearchScorer is deprecated and has no effect." | ||||
"`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`" | "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`" | ||||
",or `group_beam_search(...)`." | ",or `group_beam_search(...)`." | ||||
@@ -872,7 +872,7 @@ class GenerationMixin: | |||||
max_length = self.config.max_length | max_length = self.config.max_length | ||||
elif max_length is not None and max_new_tokens is not None: | elif max_length is not None and max_new_tokens is not None: | ||||
# Both are set, this is odd, raise a warning | # Both are set, this is odd, raise a warning | ||||
logger.warn( | |||||
logger.warning( | |||||
"Both `max_length` and `max_new_tokens` have been set but they serve the same purpose.", UserWarning | "Both `max_length` and `max_new_tokens` have been set but they serve the same purpose.", UserWarning | ||||
) | ) | ||||
@@ -1239,7 +1239,7 @@ class GenerationMixin: | |||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() | ||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() | ||||
if max_length is not None: | if max_length is not None: | ||||
logger.warn( | |||||
logger.warning( | |||||
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", | "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", | ||||
UserWarning, | UserWarning, | ||||
) | ) | ||||
@@ -1475,7 +1475,7 @@ class GenerationMixin: | |||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() | ||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() | ||||
if max_length is not None: | if max_length is not None: | ||||
logger.warn( | |||||
logger.warning( | |||||
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", | "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", | ||||
UserWarning, | UserWarning, | ||||
) | ) | ||||
@@ -1726,13 +1726,13 @@ class GenerationMixin: | |||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() | ||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() | ||||
if max_length is not None: | if max_length is not None: | ||||
logger.warn( | |||||
logger.warning( | |||||
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", | "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", | ||||
UserWarning, | UserWarning, | ||||
) | ) | ||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) | stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) | ||||
if len(stopping_criteria) == 0: | if len(stopping_criteria) == 0: | ||||
logger.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) | |||||
logger.warning("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) | |||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id | pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id | ||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id | eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id | ||||
output_scores = output_scores if output_scores is not None else self.config.output_scores | output_scores = output_scores if output_scores is not None else self.config.output_scores | ||||
@@ -2030,7 +2030,7 @@ class GenerationMixin: | |||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() | ||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() | ||||
if max_length is not None: | if max_length is not None: | ||||
logger.warn( | |||||
logger.warning( | |||||
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", | "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", | ||||
UserWarning, | UserWarning, | ||||
) | ) | ||||
@@ -2325,7 +2325,7 @@ class GenerationMixin: | |||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() | ||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() | ||||
if max_length is not None: | if max_length is not None: | ||||
logger.warn( | |||||
logger.warning( | |||||
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", | "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", | ||||
UserWarning, | UserWarning, | ||||
) | ) | ||||
@@ -401,7 +401,7 @@ class _BaseAutoModelClass: | |||||
"the option `trust_remote_code=True` to remove this error." | "the option `trust_remote_code=True` to remove this error." | ||||
) | ) | ||||
if kwargs.get("revision", None) is None: | if kwargs.get("revision", None) is None: | ||||
logger.warn( | |||||
logger.warning( | |||||
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure " | "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure " | ||||
"no malicious code has been contributed in a newer revision." | "no malicious code has been contributed in a newer revision." | ||||
) | ) | ||||
@@ -130,7 +130,7 @@ class _LazyLoadAllMappings(OrderedDict): | |||||
def _initialize(self): | def _initialize(self): | ||||
if self._initialized: | if self._initialized: | ||||
return | return | ||||
# logger.warn( | |||||
# logger.warning( | |||||
# "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. " | # "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. " | ||||
# "It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.", | # "It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.", | ||||
# FutureWarning, | # FutureWarning, | ||||
@@ -306,7 +306,7 @@ AutoModelForSpeechSeq2Seq = auto_class_update( | |||||
class AutoModelWithLMHead(_AutoModelWithLMHead): | class AutoModelWithLMHead(_AutoModelWithLMHead): | ||||
@classmethod | @classmethod | ||||
def from_config(cls, config): | def from_config(cls, config): | ||||
logger.warn( | |||||
logger.warning( | |||||
"The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " | "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " | ||||
"`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " | "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " | ||||
"`AutoModelForSeq2SeqLM` for encoder-decoder models.", | "`AutoModelForSeq2SeqLM` for encoder-decoder models.", | ||||
@@ -316,7 +316,7 @@ class AutoModelWithLMHead(_AutoModelWithLMHead): | |||||
@classmethod | @classmethod | ||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): | ||||
logger.warn( | |||||
logger.warning( | |||||
"The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " | "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " | ||||
"`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " | "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " | ||||
"`AutoModelForSeq2SeqLM` for encoder-decoder models.", | "`AutoModelForSeq2SeqLM` for encoder-decoder models.", | ||||
@@ -513,7 +513,7 @@ class BartPretrainedModel(PreTrainedModel): | |||||
class PretrainedBartModel(BartPretrainedModel): | class PretrainedBartModel(BartPretrainedModel): | ||||
def __init_subclass__(self): | def __init_subclass__(self): | ||||
logger.warn( | |||||
logger.warning( | |||||
"The class `PretrainedBartModel` has been depreciated, please use `BartPretrainedModel` instead.", | "The class `PretrainedBartModel` has been depreciated, please use `BartPretrainedModel` instead.", | ||||
FutureWarning, | FutureWarning, | ||||
) | ) | ||||
@@ -1374,7 +1374,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel): | |||||
""" | """ | ||||
if "next_sentence_label" in kwargs: | if "next_sentence_label" in kwargs: | ||||
logger.warn( | |||||
logger.warning( | |||||
"The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.", | "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.", | ||||
FutureWarning, | FutureWarning, | ||||
) | ) | ||||
@@ -724,7 +724,7 @@ class CPTDecoder(CPTPretrainedModel): | |||||
if getattr(self.config, "gradient_checkpointing", False) and self.training: | if getattr(self.config, "gradient_checkpointing", False) and self.training: | ||||
if use_cache: | if use_cache: | ||||
logger.warn( | |||||
logger.warning( | |||||
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " | "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " | ||||
"`use_cache=False`..." | "`use_cache=False`..." | ||||
) | ) | ||||
@@ -312,7 +312,7 @@ class BatchEncoding(UserDict): | |||||
""" | """ | ||||
if not self._encodings: | if not self._encodings: | ||||
raise ValueError("words() is not available when using Python-based tokenizers") | raise ValueError("words() is not available when using Python-based tokenizers") | ||||
logger.warn( | |||||
logger.warning( | |||||
"`BatchEncoding.words()` property is deprecated and should be replaced with the identical, " | "`BatchEncoding.words()` property is deprecated and should be replaced with the identical, " | ||||
"but more self-explanatory `BatchEncoding.word_ids()` property.", | "but more self-explanatory `BatchEncoding.word_ids()` property.", | ||||
FutureWarning, | FutureWarning, | ||||
@@ -1601,7 +1601,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): | |||||
f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not " | f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not " | ||||
"supported for this tokenizer. Use a model identifier or the path to a directory instead." | "supported for this tokenizer. Use a model identifier or the path to a directory instead." | ||||
) | ) | ||||
logger.warn( | |||||
logger.warning( | |||||
f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is deprecated and " | f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is deprecated and " | ||||
"won't be possible anymore in v5. Use a model identifier or the path to a directory instead.", | "won't be possible anymore in v5. Use a model identifier or the path to a directory instead.", | ||||
FutureWarning, | FutureWarning, | ||||
@@ -2163,7 +2163,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): | |||||
# Get padding strategy | # Get padding strategy | ||||
if padding is False and old_pad_to_max_length: | if padding is False and old_pad_to_max_length: | ||||
if verbose: | if verbose: | ||||
logger.warn( | |||||
logger.warning( | |||||
"The `pad_to_max_length` argument is deprecated and will be removed in a future version, " | "The `pad_to_max_length` argument is deprecated and will be removed in a future version, " | ||||
"use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or " | "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or " | ||||
"use `padding='max_length'` to pad to a max length. In this case, you can give a specific " | "use `padding='max_length'` to pad to a max length. In this case, you can give a specific " | ||||
@@ -2184,7 +2184,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): | |||||
"To pad to max length, use `padding='max_length'`." | "To pad to max length, use `padding='max_length'`." | ||||
) | ) | ||||
if old_pad_to_max_length is not False: | if old_pad_to_max_length is not False: | ||||
logger.warn("Though `pad_to_max_length` = `True`, it is ignored because `padding`=`True`.") | |||||
logger.warning("Though `pad_to_max_length` = `True`, it is ignored because `padding`=`True`.") | |||||
padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch | padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch | ||||
elif not isinstance(padding, PaddingStrategy): | elif not isinstance(padding, PaddingStrategy): | ||||
padding_strategy = PaddingStrategy(padding) | padding_strategy = PaddingStrategy(padding) | ||||
@@ -2196,7 +2196,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): | |||||
# Get truncation strategy | # Get truncation strategy | ||||
if truncation is False and old_truncation_strategy != "do_not_truncate": | if truncation is False and old_truncation_strategy != "do_not_truncate": | ||||
if verbose: | if verbose: | ||||
logger.warn( | |||||
logger.warning( | |||||
"The `truncation_strategy` argument is deprecated and will be removed in a future version, " | "The `truncation_strategy` argument is deprecated and will be removed in a future version, " | ||||
"use `truncation=True` to truncate examples to a max length. You can give a specific " | "use `truncation=True` to truncate examples to a max length. You can give a specific " | ||||
"length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the " | "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the " | ||||
@@ -3352,7 +3352,7 @@ model_inputs["labels"] = labels["input_ids"] | |||||
See the documentation of your specific tokenizer for more details on the specific arguments to the tokenizer of choice. | See the documentation of your specific tokenizer for more details on the specific arguments to the tokenizer of choice. | ||||
For a more complete example, see the implementation of `prepare_seq2seq_batch`. | For a more complete example, see the implementation of `prepare_seq2seq_batch`. | ||||
""" | """ | ||||
logger.warn(formatted_warning, FutureWarning) | |||||
logger.warning(formatted_warning, FutureWarning) | |||||
# mBART-specific kwargs that should be ignored by other models. | # mBART-specific kwargs that should be ignored by other models. | ||||
kwargs.pop("src_lang", None) | kwargs.pop("src_lang", None) | ||||
kwargs.pop("tgt_lang", None) | kwargs.pop("tgt_lang", None) | ||||