Browse Source

fix bug and make shuffle automatic

tags/v1.0.0alpha
yhcc 2 years ago
parent
commit
8467cb6e41
6 changed files with 43 additions and 37 deletions
  1. +10
    -8
      fastNLP/core/callbacks/load_best_model_callback.py
  2. +5
    -7
      fastNLP/core/callbacks/progress_callback.py
  3. +8
    -7
      fastNLP/core/dataloaders/jittor_dataloader/fdl.py
  4. +9
    -6
      fastNLP/core/dataloaders/paddle_dataloader/fdl.py
  5. +3
    -2
      fastNLP/core/dataloaders/prepare_dataloader.py
  6. +8
    -7
      fastNLP/core/dataloaders/torch_dataloader/fdl.py

+ 10
- 8
fastNLP/core/callbacks/load_best_model_callback.py View File

@@ -105,14 +105,16 @@ class LoadBestModelCallback(HasMonitorCallback):

def on_train_end(self, trainer):
if abs(self.monitor_value) != float('inf'): # 如果是 inf 说明从来没有运行过。
if self.real_save_folder:
logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value}...")
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
model_load_fn=self.model_load_fn)
else:
logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value}...")
self.buffer.seek(0)
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
# 如果是分布式且报错了,就不要加载了,防止barrier的问题
if not (trainer.driver.is_distributed() and self.encounter_exception):
if self.real_save_folder:
logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value}...")
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
model_load_fn=self.model_load_fn)
else:
logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value}...")
self.buffer.seek(0)
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
if self.delete_after_after:
if not self.encounter_exception: # 防止出现死锁。
trainer.driver.barrier()


+ 5
- 7
fastNLP/core/callbacks/progress_callback.py View File

@@ -22,9 +22,10 @@ class ProgressCallback(HasMonitorCallback):
self.best_monitor_step = -1
self.best_results = None

def record_better_monitor(self, trainer):
def record_better_monitor(self, trainer, results):
self.best_monitor_step = trainer.global_forward_batches
self.best_monitor_epoch = trainer.cur_epoch_idx
self.best_results = results

def on_train_end(self, trainer):
if self.best_monitor_epoch != -1:
@@ -138,7 +139,7 @@ class RichCallback(ProgressCallback):
characters = '-'
if self.monitor is not None:
if self.is_better_results(results, keep_if_better=True):
self.record_better_monitor(trainer)
self.record_better_monitor(trainer, results)
if abs(self.monitor_value) != float('inf'):
rule_style = 'spring_green3'
text_style = '[bold]'
@@ -154,7 +155,6 @@ class RichCallback(ProgressCallback):
self.progress_bar.console.print_json(results)
else:
self.progress_bar.print(results)
self.best_results = results

def clear_tasks(self):
for key, taskid in self.task2id.items():
@@ -222,7 +222,7 @@ class RawTextCallback(ProgressCallback):
text = ''
if self.monitor is not None:
if self.is_better_results(results, keep_if_better=True):
self.record_better_monitor(trainer)
self.record_better_monitor(trainer, results)
if abs(self.monitor_value) != float('inf'):
text = '+'*self.num_signs + base_text + '+'*self.num_signs
if len(text) == 0:
@@ -234,7 +234,6 @@ class RawTextCallback(ProgressCallback):
if self.format_json:
results = json.dumps(results)
logger.info(results)
self.best_results = results

@property
def name(self): # progress bar的名称
@@ -311,7 +310,7 @@ class TqdmCallback(ProgressCallback):
text = ''
if self.monitor is not None:
if self.is_better_results(results, keep_if_better=True):
self.record_better_monitor(trainer)
self.record_better_monitor(trainer, results)
if abs(self.monitor_value) != float('inf'):
text = '+'*self.num_signs + base_text + '+'*self.num_signs
if len(text) == 0:
@@ -323,7 +322,6 @@ class TqdmCallback(ProgressCallback):
if self.format_json:
results = json.dumps(results)
logger.info(results)
self.best_results = results

def clear_tasks(self):
for key, taskid in self.task2id.items():


+ 8
- 7
fastNLP/core/dataloaders/jittor_dataloader/fdl.py View File

@@ -200,7 +200,7 @@ class JittorDataLoader:
return self.cur_batch_indices


def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = False,
def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = None,
drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
collate_fn: Union[None, str, Callable] = "auto",
@@ -230,7 +230,8 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
:param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 :class:`Dict` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过改参数
设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。 默认为 ``16``。
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。
:param shuffle: 是否打乱数据集, 默认为 ``False``。
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 'train' 则设置其 shuffle 为 True ,
其它的为 False 。
:param drop_last: 当 ``drop_last=True`` 时,:class:`JittorDataLoader` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据;
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。
:param num_workers: 当 ``num_workers > 0`` 时, :class:`JittorDataLoader` 会开启 num_workers 个子进程来处理数据, 可以加快
@@ -258,7 +259,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
dl_bundle = {}
for name, ds in ds_or_db.iter_datasets():
if 'train' in name:
dl_bundle[name] = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle,
dl_bundle[name] = JittorDataLoader(ds, batch_size=batch_size, shuffle=True if shuffle is None else shuffle,
drop_last=drop_last, num_workers=num_workers,
buffer_size=buffer_size,
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array,
@@ -267,7 +268,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
else:
dl_bundle[name] = JittorDataLoader(ds,
batch_size=non_train_batch_size if non_train_batch_size else batch_size,
shuffle=shuffle,
shuffle=False if shuffle is None else shuffle,
drop_last=drop_last, num_workers=num_workers,
buffer_size=buffer_size,
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array,
@@ -279,14 +280,14 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
ds_dict = {}
for name, ds in ds_or_db.items():
if 'train' in name:
dl = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle,
dl = JittorDataLoader(ds, batch_size=batch_size, shuffle=True if shuffle is None else shuffle,
drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size,
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless,
collate_fn=collate_fn)
else:
dl = JittorDataLoader(ds,
batch_size=non_train_batch_size if non_train_batch_size else batch_size,
shuffle=shuffle,
shuffle=False if shuffle is None else shuffle,
drop_last=drop_last, num_workers=num_workers,
buffer_size=buffer_size,
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array,
@@ -296,7 +297,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
return ds_dict

elif isinstance(ds_or_db, HasLenGetitemType):
dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle,
dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=False if shuffle is None else shuffle,
drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size,
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless,
collate_fn=collate_fn)


+ 9
- 6
fastNLP/core/dataloaders/paddle_dataloader/fdl.py View File

@@ -293,7 +293,8 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为
dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, shuffle 参数均失效。
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。
:param shuffle: 是否将数据打乱,若``shuffle=True``则会将dataset打乱;若否则什么也不做。
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 'train' 则设置其 shuffle 为 True ,
其它的为 False 。
:param drop_last: 当 ``drop_last=True`` 时,``PaddleDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据;
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``.
@@ -326,7 +327,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places,
return_list=return_list,
batch_sampler=batch_sampler, batch_size=batch_size,
shuffle=shuffle,
shuffle=True if shuffle is None else shuffle,
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
use_shared_memory=use_shared_memory,
use_buffer_reader=use_buffer_reader,
@@ -337,7 +338,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
return_list=return_list,
batch_sampler=batch_sampler,
batch_size=non_train_batch_size if non_train_batch_size else batch_size,
shuffle=shuffle,
shuffle=False if shuffle is None else shuffle,
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
use_shared_memory=use_shared_memory,
use_buffer_reader=use_buffer_reader,
@@ -350,7 +351,8 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
for name, ds in ds_or_db.items():
if 'train' in name:
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list,
batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle,
batch_sampler=batch_sampler, batch_size=batch_size,
shuffle=False if shuffle is None else shuffle,
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
timeout=timeout, worker_init_fn=worker_init_fn,
@@ -359,7 +361,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list,
batch_sampler=batch_sampler,
batch_size=non_train_batch_size if non_train_batch_size else batch_size,
shuffle=shuffle,
shuffle=False if shuffle is None else shuffle,
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
timeout=timeout, worker_init_fn=worker_init_fn,
@@ -369,7 +371,8 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,

elif isinstance(ds_or_db, HasLenGetitemType):
dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list,
batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle,
batch_sampler=batch_sampler, batch_size=batch_size,
shuffle=False if shuffle is None else shuffle,
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)


+ 3
- 2
fastNLP/core/dataloaders/prepare_dataloader.py View File

@@ -13,7 +13,7 @@ from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS
from ..log import logger


def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False,
def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = None, drop_last: bool = False,
collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0,
backend: str = 'auto'):
"""
@@ -28,7 +28,8 @@ def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, dro
* 为字典型 或 :class:`~fastNLP.io.DataBundle` 数据时,返回 `Dict` 类型的数据。

:param batch_size: 批次大小。
:param shuffle: 是否打乱数据集。
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 'train' 则设置其 shuffle 为 True ,
其它的为 False 。
:param drop_last: 当最后一个 batch 不足 batch_size 数量的是否,是否丢弃。
:param collate_fn: 用于处理一个 batch 的函数,一般包括 padding 和转为 tensor。有以下三种取值:



+ 8
- 7
fastNLP/core/dataloaders/torch_dataloader/fdl.py View File

@@ -218,7 +218,7 @@ class TorchDataLoader(DataLoader):

def prepare_torch_dataloader(ds_or_db,
batch_size: int = 16,
shuffle: bool = False,
shuffle: bool = None,
sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto',
@@ -252,7 +252,8 @@ def prepare_torch_dataloader(ds_or_db,

:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。
:param non_train_batch_size: 非 'train' 数据集的 ``TorchDataLoader`` 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。
:param shuffle: 是否打乱数据集, 默认为 ``False``。
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 'train' 则设置其 shuffle 为 True ,
其它的为 False 。
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index ,
默认为None, 当其不为 None 时, shuffle 参数无效。
:param non_train_sampler: 非 'train' 数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index ,
@@ -290,7 +291,7 @@ def prepare_torch_dataloader(ds_or_db,
for name, ds in ds_or_db.iter_datasets():
if 'train' in name:
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size,
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
shuffle=True if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
@@ -300,7 +301,7 @@ def prepare_torch_dataloader(ds_or_db,
else:
dl_bundle[name] = TorchDataLoader(dataset=ds,
batch_size=non_train_batch_size if non_train_batch_size else batch_size,
shuffle=shuffle,
shuffle=False if shuffle is None else shuffle,
sampler=non_train_sampler if non_train_sampler else sampler,
batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
@@ -316,7 +317,7 @@ def prepare_torch_dataloader(ds_or_db,
for name, ds in ds_or_db.items():
if 'train' in name:
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size,
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
shuffle=True if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
@@ -326,7 +327,7 @@ def prepare_torch_dataloader(ds_or_db,
else:
dl_bundle[name] = TorchDataLoader(dataset=ds,
batch_size=non_train_batch_size if non_train_batch_size else batch_size,
shuffle=shuffle,
shuffle=False if shuffle is None else shuffle,
sampler=non_train_sampler if non_train_sampler else sampler,
batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
@@ -340,7 +341,7 @@ def prepare_torch_dataloader(ds_or_db,

elif isinstance(ds_or_db, HasLenGetitemType):
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
shuffle=False if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,


Loading…
Cancel
Save