
little change

tags/v1.0.0alpha
YWMditto 3 years ago
parent commit 3bbf27283d
17 changed files with 213 additions and 186 deletions
  1. +20 -1   fastNLP/core/callbacks/callback.py
  2. +49 -19  fastNLP/core/callbacks/checkpoint_callback.py
  3. +9  -5   fastNLP/core/callbacks/early_stop_callback.py
  4. +3  -2   fastNLP/core/callbacks/load_best_model_callback.py
  5. +8  -6   fastNLP/core/callbacks/progress_callback.py
  6. +4  -5   fastNLP/core/controllers/evaluator.py
  7. +2  -4   fastNLP/core/controllers/loops/train_batch_loop.py
  8. +3  -4   fastNLP/core/controllers/trainer.py
  9. +3  -3   fastNLP/core/controllers/utils/state.py
  10. +1  -12  fastNLP/core/drivers/driver.py
  11. +21 -21  fastNLP/core/drivers/paddle_driver/paddle_driver.py
  12. +1  -1   fastNLP/core/drivers/torch_driver/ddp.py
  13. +1  -0   fastNLP/core/drivers/torch_driver/torch_driver.py
  14. +27 -79  tests/core/controllers/test_trainer_paddle.py
  15. +45 -23  tests/core/drivers/paddle_driver/test_single_device.py
  16. +1  -0   tests/core/utils/test_utils.py
  17. +15 -1   tests/helpers/datasets/normal_data.py

+20 -1 fastNLP/core/callbacks/callback.py

@@ -390,4 +390,23 @@ class HasMonitorCallback(Callback):
         if (self.larger_better and monitor_value1 > monitor_value2) or \
                 (not self.larger_better and monitor_value1 < monitor_value2):
             better = True
         return better

+    @property
+    def monitor_name(self):
+        """
+        Return the name of the monitor; if the monitor is a callable, return that function's name.
+
+        :return:
+        """
+        if callable(self.monitor):
+            try:
+                monitor_name = self.monitor.__qualname__
+            except:
+                monitor_name = self.monitor.__name__
+        elif self.monitor is None:
+            return None
+        else:
+            # This has to be monitor rather than real_monitor, because when the user runs again,
+            # real_monitor is re-initialized to monitor.
+            monitor_name = str(self.monitor)
+        return monitor_name
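
For intuition, a minimal standalone mirror of the monitor_name logic added above (the MiniMonitorHolder class and the best_f1 function are hypothetical, purely for illustration):

    # Hypothetical mirror of HasMonitorCallback.monitor_name, runnable on its own.
    class MiniMonitorHolder:
        def __init__(self, monitor):
            self.monitor = monitor

        @property
        def monitor_name(self):
            if callable(self.monitor):
                try:
                    return self.monitor.__qualname__
                except AttributeError:
                    return self.monitor.__name__
            if self.monitor is None:
                return None
            return str(self.monitor)

    def best_f1(results):  # a callable monitor: evaluation results (dict) -> float
        return results["f1#f1"]

    assert MiniMonitorHolder("acc#acc").monitor_name == "acc#acc"
    assert MiniMonitorHolder(best_f1).monitor_name == "best_f1"
    assert MiniMonitorHolder(None).monitor_name is None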

+49 -19 fastNLP/core/callbacks/checkpoint_callback.py

@@ -19,11 +19,11 @@ from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir
 class CheckpointCallback(HasMonitorCallback):
     def __init__(
             self,
-            monitor,
+            monitor: Optional[Union[str, Callable]] = None,
             save_folder: Optional[Union[str, Path]] = None,
             save_every_n_epochs: Optional[int] = None,
             save_every_n_batches: Optional[int] = None,
-            save_last: bool = True,
+            save_last: bool = False,
             save_topk: Optional[int] = None,
             save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None,
             larger_better: bool = True,
@@ -31,12 +31,32 @@ class CheckpointCallback(HasMonitorCallback):
             model_save_fn: Optional[Callable] = None,
             **kwargs,
     ):
+        """
+        Please use ModelCheckpointCallback or TrainerCheckpointCallback instead.
+
+        :param monitor: The metric value to monitor. If no exactly matching name is found in the
+            evaluation results, the shortest-common-string algorithm is used to pick the closest
+            match as the monitor. If None, the monitor set on the Trainer is used. A function may
+            also be passed, which takes the evaluation results (a dict) and returns a float as the
+            monitor value.
+        :param save_folder: The folder to save into. fastNLP creates a timestamped subfolder under
+            it and saves there, so different runs are saved into different timestamp folders. If
+            None, the current working directory is used.
+        :param save_every_n_epochs: Save once every this many epochs.
+        :param save_every_n_batches: Save once every this many batches.
+        :param save_last: If True, save at the end of every epoch, overwriting the previous save.
+        :param save_topk: Keep the top-K checkpoints ranked by the monitor value.
+        :param save_on_exception: Whether to save when an exception is raised; pass the exception
+            classes that should be caught.
+        :param larger_better: Whether a larger monitor value is better.
+        :param only_state_dict: Whether to save only the model's state_dict. Ignored when
+            model_save_fn is not None.
+        :param model_save_fn: A custom save function. Whenever a save is triggered this function is
+            called; it should accept a folder as its argument and return nothing. If model_save_fn
+            is given, fastNLP performs no model saving of its own. In the multi-card setting it is
+            only run on rank 0.
+        :param kwargs:
+        """
         super().__init__(monitor=monitor, larger_better=larger_better,
                          must_have_monitor=save_topk is not None)
         if save_folder is None:
             logger.warning(
                 "Parameter `path` is None, and we will use the current work directory to find and load your model.")
             save_folder = Path.cwd()
+        save_folder = Path(save_folder)
         if not save_folder.exists():
             raise NotADirectoryError(f"Path '{save_folder.absolute()}' is not existed!")
         elif save_folder.is_file():
@@ -71,7 +91,7 @@ class CheckpointCallback(HasMonitorCallback):
         else:
             save_on_exception = []
 
-        self.save_folder = Path(save_folder)
+        self.save_folder = save_folder
         self.save_every_n_epochs = save_every_n_epochs
         self.save_every_n_batches = save_every_n_batches
         self.save_last = save_last
@@ -88,8 +108,7 @@ class CheckpointCallback(HasMonitorCallback):
         # Note: we must make sure only process 0 performs this, because when the user launches
         # processes with `python -m torch.distributed.launch`, FASTNLP_LAUNCH_TIME has a different
         # value on each process;
         self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])
-        # We only need to make sure this mkdir happens on process 0; the actual saving later on is
-        # not executed by the other processes anyway;
-        synchronize_mkdir(self.timestamp_path)
+        # The folder is only created when a save is actually about to happen.


     def on_after_trainer_initialized(self, trainer, driver):
         if self.save_topk is not None:
@@ -98,8 +117,6 @@ class CheckpointCallback(HasMonitorCallback):
logger.warning("You set `save_topk`, but `evaluate_dataloaders` is not set in Trainer.") logger.warning("You set `save_topk`, but `evaluate_dataloaders` is not set in Trainer.")


def on_validate_end(self, trainer, results): def on_validate_end(self, trainer, results):
if len(results) == 0:
return
self._save_topk(trainer, results) self._save_topk(trainer, results)


def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): def on_train_epoch_end(self, trainer: "fastNLP.Trainer"):
@@ -136,16 +153,17 @@ class CheckpointCallback(HasMonitorCallback):
         states['timestamp_path'] = str(self.timestamp_path.absolute())
         states['_topk_model'] = deepcopy(self._topk_model)
         states['save_topk'] = 0 if self.save_topk is None else self.save_topk
-        states['_real_monitor'] = self._real_monitor
+        if isinstance(self._real_monitor, str):
+            states['_real_monitor'] = self._real_monitor
         return states
 
     def on_load_checkpoint(self, trainer, states: Optional[Dict]):
         timestamp_path = states['timestamp_path']
         if not os.path.exists(timestamp_path):
-            logger.info(f"The resuming save folder {timestamp_path} is not exists, will checkpoint save to "
+            logger.info(f"The resuming checkpoint folder {timestamp_path} is not exists, will checkpoint save to "
                         f" {self.timestamp_path.absolute()}.")
         else:
-            logger.info(f"Resume to save in path: {timestamp_path}.")
+            logger.info(f"Resume to checkpoint in path: {timestamp_path}.")
             self.timestamp_path = Path(timestamp_path)
         _topk_model = states['_topk_model']
         save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk'])
@@ -153,7 +171,8 @@ class CheckpointCallback(HasMonitorCallback):
             assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \
                                                 f"as {save_topk}."
         self._topk_model.update(self._topk_model)
-        self._real_monitor = states["real_monitor"]
+        self._real_monitor = states["_real_monitor"]
 
     def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict):
         """
@@ -231,9 +250,9 @@ class ModelCheckpointCallback(CheckpointCallback):
     If model_save_fn is None, a fastnlp_model.pkl.tar file is generated inside each of the folders
     above. If model_save_fn is not None, fastNLP passes the absolute path of the folder to that
     function and creates no files inside the folder itself.
 
-    :param monitor: The name of the metric to monitor. If no exactly matching name is found in the
-        evaluation results, the shortest-common-string algorithm is used to pick the closest match
-        as the monitor. If None, the value is fetched from the Trainer. A function may also be
-        passed, which takes the evaluation results (a dict) and returns a float as the monitor value.
+    :param monitor: The metric to monitor. If no exactly matching name is found in the evaluation
+        results, the shortest-common-string algorithm is used to pick the closest match as the
+        monitor. If None, the monitor set on the Trainer is used. A function may also be passed,
+        which takes the evaluation results (a dict) and returns a float as the monitor value.
     :param save_folder: The folder to save into. fastNLP creates a timestamped subfolder under it
         and saves there, so different runs are saved into different timestamp folders. If None,
         the current working directory is used.
     :param save_every_n_epochs: Save once every this many epochs.
@@ -249,6 +268,11 @@ class ModelCheckpointCallback(CheckpointCallback):
""" """
@property @property
def save_fn_name(self): def save_fn_name(self):
"""
调用 Trainer 中的哪个函数。

:return:
"""
return 'save_model' return 'save_model'


@property @property
@@ -257,7 +281,7 @@ class ModelCheckpointCallback(CheckpointCallback):
         This value decides whether two CheckpointCallback instances can share the
         checkpoint-resume state;
         :return:
         """
-        return f"model_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"
+        return f"model_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"
 
     @property
     def folder_prefix(self):
@@ -279,9 +303,9 @@ class TrainerCheckpointCallback(CheckpointCallback):
     If model_save_fn is None, two files are generated inside each of the folders above:
     fastnlp_trainer.pkl.tar and fastnlp_model.pkl.tar. If model_save_fn is not None, fastNLP only
     generates the fastnlp_trainer.pkl.tar file inside each folder.
 
-    :param monitor: The name of the metric to monitor. If no exactly matching name is found in the
-        evaluation results, the shortest-common-string algorithm is used to pick the closest match
-        as the monitor. If None, the value is fetched from the Trainer. A function may also be
-        passed, which takes the evaluation results (a dict) and returns a float as the monitor value.
+    :param monitor: The metric to monitor. If no exactly matching name is found in the evaluation
+        results, the shortest-common-string algorithm is used to pick the closest match as the
+        monitor. If None, the monitor set on the Trainer is used. A function may also be passed,
+        which takes the evaluation results (a dict) and returns a float as the monitor value.
     :param save_folder: The folder to save into. fastNLP creates a timestamped subfolder under it
         and saves there, so different runs are saved into different timestamp folders. If None,
         the current working directory is used.
     :param save_every_n_epochs: Save once every this many epochs.
@@ -297,6 +321,11 @@ class TrainerCheckpointCallback(CheckpointCallback):
""" """
@property @property
def save_fn_name(self): def save_fn_name(self):
"""
调用 Trainer 中的哪个函数。

:return:
"""
return 'save' return 'save'


@property @property
@@ -305,7 +334,8 @@ class TrainerCheckpointCallback(CheckpointCallback):
         This value decides whether two CheckpointCallback instances can share the
         checkpoint-resume state;
         :return:
         """
-        return f"trainer_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"
+        return f"trainer_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"
+
 
     @property
     def folder_prefix(self):
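
For orientation, a hedged usage sketch of the callback configured by the parameters documented above; the Trainer wiring is assumed rather than taken from this diff:

    from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback  # path as in this commit

    callback = ModelCheckpointCallback(
        monitor="acc#acc",            # or a callable taking the evaluation results dict
        save_folder="./checkpoints",  # a timestamped subfolder is created underneath
        save_topk=3,                  # keep the 3 best checkpoints by monitor value
        larger_better=True,
        save_last=False,              # the new default introduced on the base class by this commit
    )
    # trainer = Trainer(..., callbacks=[callback]); trainer.run()   # wiring assumed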


+9 -5 fastNLP/core/callbacks/early_stop_callback.py

@@ -12,8 +12,9 @@ class EarlyStopCallback(HasMonitorCallback):
     def __init__(self, monitor: Union[str, Callable] = None, larger_better: bool = True, patience: int = 10):
         """
 
-        :param str monitor: The metric value to monitor. If None, the monitor set on the Trainer
-            is used. A function may also be passed, which takes the evaluation results (a dict)
-            and returns a float as the monitor value.
+        :param str monitor: The metric value to monitor. If no exactly matching name is found in
+            the evaluation results, the shortest-common-string algorithm is used to pick the
+            closest match as the monitor. If None, the monitor set on the Trainer is used. A
+            function may also be passed, which takes the evaluation results (a dict) and returns
+            a float as the monitor value.
         :param larger_better: Whether a larger monitor value is better.
         :param patience: Stop after this many validations without improvement.
         """
@@ -46,17 +47,20 @@ class EarlyStopCallback(HasMonitorCallback):
         states = {
             'patience': self.patience,
             'wait': self.wait,
-            'monitor': self.monitor,
             'monitor_value': self.monitor_value
         }
+        if not callable(self._real_monitor):
+            states['_real_monitor'] = self._real_monitor
         return states
 
     def on_load_checkpoint(self, trainer, states):
         self.patience = states['patience']
         self.wait = states['wait']
-        self.monitor = states['monitor']
         self.monitor_value = float(states['monitor_value'])
+        if '_real_monitor' in states:
+            self._real_monitor = states['_real_monitor']
 
+    @property
     def callback_name(self):
-        return f'EarlyStopCallback#monitor-{self.monitor}#patience-{self.patience}'
+        return f'EarlyStopCallback#monitor-{self.monitor_name}#patience-{self.patience}'
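
The state handling above serializes only what can round-trip safely; a minimal standalone sketch of the same idea with invented values (callable monitors are deliberately excluded because they cannot be stored in a checkpoint):

    # save side: keep '_real_monitor' only when it is a plain string
    real_monitor = "acc#acc"
    states = {"patience": 10, "wait": 3, "monitor_value": 0.91}
    if not callable(real_monitor):
        states["_real_monitor"] = real_monitor

    # load side: tolerate checkpoints written before this commit, which lack the key
    patience = states["patience"]
    wait = states["wait"]
    monitor_value = float(states["monitor_value"])
    real_monitor = states.get("_real_monitor", None)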



+3 -2 fastNLP/core/callbacks/load_best_model_callback.py

@@ -21,8 +21,9 @@ class LoadBestModelCallback(HasMonitorCallback):
""" """
保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。


:param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为
evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。
:param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结
果(字典类型),返回一个 float 值作为 monitor 的结果。
:param larger_better: 该 metric 值是否是越大越好。 :param larger_better: 该 metric 值是否是越大越好。
:param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保
不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。
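
All of the monitor parameters above also accept a callable that maps the evaluation results dict to a float; a small hypothetical example (the metric keys are invented):

    def my_monitor(results: dict) -> float:
        # fold two evaluation entries into one scalar to monitor
        return results["acc#acc"] - 0.1 * results["loss#loss"]

    print(my_monitor({"acc#acc": 0.9, "loss#loss": 0.4}))  # ~0.86
    # LoadBestModelCallback(monitor=my_monitor, larger_better=True)  # wiring assumed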


+8 -6 fastNLP/core/callbacks/progress_callback.py

@@ -44,10 +44,11 @@ class RichCallback(ProgressCallback):


     :param print_every: Update the display every this many batches.
     :param loss_round_ndigit: How many significant digits of the loss to display.
-    :param monitor: When the result under this key improves, it is highlighted in a different
-        color as a hint. If None, the monitor set on the trainer is used. A function may also be
-        passed, which takes the evaluation results (a dict) and returns a float as the monitor value.
-    :param larger_better: Whether a larger monitor result is better.
-    :param format_json: Whether to format the json before printing.
+    :param monitor: When the result under this key improves, it is highlighted in a different
+        color as a hint. The metric value to monitor. If no exactly matching name is found in the
+        evaluation results, the shortest-common-string algorithm is used to pick the closest match
+        as the monitor. If None, the monitor set on the Trainer is used. A function may also be
+        passed, which takes the evaluation results (a dict) and returns a float as the monitor value.
+    :param larger_better: Whether a larger monitor result is better.
+    :param format_json: Whether to format the json before printing.
     """
     super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False)
     self.print_every = print_every
@@ -136,8 +137,9 @@ class RawTextCallback(ProgressCallback):


     :param print_every: Update the display every this many batches.
     :param loss_round_ndigit: How many significant digits of the loss to display.
-    :param monitor: When the result under this key improves, it is highlighted in a different
-        color as a hint. A function may also be passed, which takes the evaluation results (a
-        dict) and returns a float as the monitor value.
+    :param monitor: When the result under this key improves, it is highlighted in a different
+        color as a hint. The metric value to monitor. If no exactly matching name is found in the
+        evaluation results, the shortest-common-string algorithm is used to pick the closest match
+        as the monitor. If None, the monitor set on the Trainer is used. A function may also be
+        passed, which takes the evaluation results (a dict) and returns a float as the monitor value.
     :param larger_better: Whether a larger monitor result is better.
     :param format_json: Whether to format the json before printing.
     """


+4 -5 fastNLP/core/controllers/evaluator.py

@@ -36,7 +36,7 @@ class Evaluator:
         model,
         dataloaders,
         metrics: Optional[Union[Dict, Metric]] = None,
-        driver: Union[str, Driver] = 'single',
+        driver: Union[str, Driver] = 'torch',
         device: Optional[Union[int, List[int], str]] = None,
         batch_step_fn: Optional[callable] = None,
         evaluate_fn: Optional[str] = None,  # first try evaluate_step; fall back to forward; may be callable
@@ -49,8 +49,8 @@ class Evaluator:
     ):
         """
 
-        :param dataloaders:
         :param model:
+        :param dataloaders:
         :param metrics: The metrics to use. Must be a dict whose keys are metric names and whose
             values are Metric objects. fastNLP metrics, torchmetrics, allennlp metrics, etc. are
             supported.
         :param driver: The driver to use.
@@ -120,7 +120,8 @@ class Evaluator:


         if evaluate_fn is not None and not isinstance(evaluate_fn, str):
             raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")
-        self._evaluate_step, self._evaluate_step_signature_fn = self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn)
+        self._evaluate_step, self._evaluate_step_signature_fn = \
+            self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn)
         self.evaluate_fn = evaluate_fn
 
         self.dataloaders = {}
@@ -134,8 +135,6 @@ class Evaluator:


         self.driver.barrier()
 
-        self.driver.check_dataloader_legality(self.dataloaders, "dataloaders", is_train=False)
 
     def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict:
         """
         Returns a dict whose keys are the metric names and whose values are the corresponding metric results.
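
A hedged sketch of driving an Evaluator end to end, based only on the signature shown here; the tiny model and the exact call conventions are assumptions, not verified against this alpha:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from fastNLP import Evaluator                      # import path assumed
    from fastNLP.core.metrics.accuracy import Accuracy

    class TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(4, 2)

        def evaluate_step(self, batch):                # looked up by get_model_call_fn
            x, y = batch
            return {"pred": self.fc(x), "target": y}

    ds = TensorDataset(torch.rand(32, 4), torch.randint(0, 2, (32,)))
    evaluator = Evaluator(
        model=TinyModel(),
        dataloaders=DataLoader(ds, batch_size=8),
        metrics={"acc": Accuracy()},                   # dict: metric name -> Metric object
        driver="torch",                                # the new default after this commit
        device="cpu",
    )
    results = evaluator.run()                          # a dict, e.g. {"acc#acc": 0.5}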


+2 -4 fastNLP/core/controllers/loops/train_batch_loop.py

@@ -20,7 +20,7 @@ class TrainBatchLoop(Loop):
             else lambda *args, **kwargs: None
         dataloader = iter(dataloader)
         indices = None
-        while True:
+        while trainer.batch_idx_in_epoch <= trainer.num_batches_per_epoch:
             try:
                 trainer.on_fetch_data_begin()
                 batch = next(dataloader)
@@ -30,10 +30,8 @@ class TrainBatchLoop(Loop):
                 batch = trainer.move_data_to_device(batch)
             except StopIteration:
                 break
-            except EarlyStopException:  # the early-stop exception is handled in the Trainer
-                break
             except BaseException as e:
-                if indices:
+                if indices and not isinstance(e, EarlyStopException):
                     logger.debug(f"The following exception happens when running on samples: {indices}")
                 raise e
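
The new loop condition bounds an epoch by the batch counter instead of relying only on StopIteration; a standalone sketch of the control flow (the names mirror the trainer state, everything else is invented):

    class _State:                        # hypothetical stand-in for the trainer
        batch_idx_in_epoch = 0
        num_batches_per_epoch = 3

    trainer = _State()
    dataloader = iter(range(10))         # could yield more than one epoch's budget

    while trainer.batch_idx_in_epoch <= trainer.num_batches_per_epoch:
        try:
            batch = next(dataloader)
        except StopIteration:            # the other exit path, as in the real loop
            break
        trainer.batch_idx_in_epoch += 1  # advanced by the trainer itself in fastNLP

    print(trainer.batch_idx_in_epoch)    # 4: the guard stops the loop although data remains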




+3 -4 fastNLP/core/controllers/trainer.py

@@ -264,7 +264,6 @@ class Trainer(TrainerEventTrigger):
         self.on_after_trainer_initialized(self.driver)
 
         self.driver.barrier()
-        self.driver.check_dataloader_legality(self.train_dataloader, "train_dataloader", is_train=True)
 
     def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1,
             num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True,
@@ -310,7 +309,7 @@ class Trainer(TrainerEventTrigger):


         self.num_batches_per_epoch = len(self.dataloader)
         self.total_batches = self.num_batches_per_epoch * self.n_epochs
+        self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch
         self.on_train_begin()
         self.driver.barrier()
         self.driver.zero_grad(self.set_grad_to_none)
@@ -637,6 +636,8 @@ class Trainer(TrainerEventTrigger):
         :param folder: The path where the checkpoint-resume states are saved;
         :param resume_training: Whether to resume training from the last batch, or only from the
             most recent epoch; note that if resume_training=True, only the states of the model and
             optimizers are loaded, while all other objects are simply reset according to the
             user's Trainer initialization;
+        :param only_state_dict: Whether the saved model contains only the weights.
+        :param model_load_fn: The model-loading function to use; its parameter should be a folder,
+            and it returns nothing.
         """
         self.driver.barrier()
         if isinstance(folder, str):
@@ -675,8 +676,6 @@ class Trainer(TrainerEventTrigger):
         # The principle here is that 'the number of batches still to come' + 'batch_idx_in_epoch'
         # should equal 'the total number of batches of an uninterrupted run'. Since 'the number of
         # batches still to come' is determined by how many samples remain, the equation can only
         # be balanced by adjusting 'batch_idx_in_epoch'.
         self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch')
-        self.trainer_state.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + \
-                                                    self.batch_idx_in_epoch
         # This guards against the user saving again after Trainer.load before the current epoch ends
         self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch
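
A small worked example of that bookkeeping (numbers invented): with num_batches_per_epoch = 100, resuming at cur_epoch_idx = 2 and batch_idx_in_epoch = 37, run() now recomputes

    global_forward_batches = num_batches_per_epoch * cur_epoch_idx + batch_idx_in_epoch
                           = 100 * 2 + 37 = 237

so the counter is derived from the freshly restored epoch and batch indices on every run, instead of being stored by load().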




+3 -3 fastNLP/core/controllers/utils/state.py

@@ -65,10 +65,10 @@ class TrainerState:
""" """
n_epochs: Optional[int] = None # 无论如何重新算 n_epochs: Optional[int] = None # 无论如何重新算


cur_epoch_idx: Optional[int] = None # 断点重训; 仅当 resume=False 时为0;
global_forward_batches: Optional[int] = None # 断点重训
cur_epoch_idx: Optional[int] = 0 # 断点重训; 仅当 resume=False 时为0;
global_forward_batches: Optional[int] = 0 # 断点重训


batch_idx_in_epoch: Optional[int] = None # 断点重训
batch_idx_in_epoch: Optional[int] = 0 # 断点重训


num_batches_per_epoch: Optional[int] = None # 无论如何重新算 num_batches_per_epoch: Optional[int] = None # 无论如何重新算




+1 -12 fastNLP/core/drivers/driver.py

@@ -86,7 +86,7 @@ class Driver(ABC):
         function;
 
         :param batch: the current batch of data; may be a dict or any other type;
-        :param fn: the function passed in by the Trainer for one forward pass of the network;
+        :param fn: the function to call for one computation step.
         :param signature_fn: the signature function passed in by the Trainer for one forward pass
             of the network. When batch is a Dict we automatically call auto_param_call, and some
             wrapped models need to expose their real function signature, e.g. DistributedDataParallel
             is invoked via forward but needs the signature of model.module.forward;
         :return: the result returned by `fn` (should be a dict or a dataclass, but we do not need to check it);
@@ -126,17 +126,6 @@ class Driver(ABC):
     def model(self, model):
         self._model = model
 
-    @staticmethod
-    def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
-        r"""
-        Called after the trainer or evaluator sets its dataloader, to check that the dataloader is
-        legal, since different deep-learning frameworks require different dataloader behavior;
-
-        :param dataloader: the input `dataloader` to check;
-        :param dataloader_name:
-        """
-        raise NotImplementedError("Each specific driver should implemented its own `check_dataloader_legality` function.")
-
     @property
     def optimizers(self) -> List:
         r"""


+21 -21 fastNLP/core/drivers/paddle_driver/paddle_driver.py

@@ -34,10 +34,10 @@ if _NEED_IMPORT_PADDLE:
     from paddle.optimizer import Optimizer
 
 _reduces = {
-    'max': paddle.max,
-    'min': paddle.min,
-    'mean': paddle.mean,
-    'sum': paddle.sum
+    "max": paddle.max,
+    "min": paddle.min,
+    "mean": paddle.mean,
+    "sum": paddle.sum
 }
 
 class PaddleDriver(Driver):
@@ -254,24 +254,24 @@ class PaddleDriver(Driver):
         else:
             raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
 
-        num_consumed_batches = states.pop('num_consumed_batches')
-        if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
+        num_consumed_batches = states.pop("num_consumed_batches")
+        if hasattr(sampler, "state_dict") and callable(sampler.state_dict):
             sampler_states = sampler.state_dict()
             # If present, num_consumed_samples needs special handling: because the DataLoader
             # prefetches, using num_consumed_samples from the sampler directly would overcount
             # what was actually consumed.
-            num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
+            num_consumed_samples_array = sampler_states.pop("num_consumed_samples_array", None)
             if num_consumed_samples_array is not None:
-                if isinstance(sampler, ReproducibleSampler):  # for a plain sampler, batch_size must be taken into account
-                    try:
-                        num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
-                    except:  # batch_size may be None, in which case we can only lose some precision
-                        num_consumed_batches = sampler_states['num_consumed_samples']
-                sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
-                assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
+                sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
+            else:
+                try:
+                    sampler_states["num_consumed_samples"] = num_consumed_batches * dataloader_args.batch_size
+                except:  # batch_size may be None, in which case we can only lose some precision
+                    pass
+            assert sampler_states["num_consumed_samples"] != -1, "This is a bug, please report."
         else:
             raise RuntimeError(
-                'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')
+                "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.")
+        states["sampler_states"] = sampler_states
 
         # 2. Save the model's state;
         if should_save_model:
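
The recovery logic above boils down to: prefer the exact cumulative array the sampler recorded, and fall back to batches * batch_size otherwise. A standalone sketch with invented numbers (3 batches of size 4 actually consumed, while prefetching inflated the raw counter):

    sampler_states = {"num_consumed_samples": 20}        # inflated by DataLoader prefetching
    num_consumed_samples_array = [0, 4, 8, 12, 16, 20]   # cumulative samples after each batch
    num_consumed_batches = 3
    batch_size = 4

    if num_consumed_samples_array is not None:
        # exact: index the cumulative array by the number of batches actually consumed
        sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]  # 12
    else:
        # estimate: may lose precision, e.g. when batch_size is None
        sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size                 # 12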
@@ -326,7 +326,7 @@ class PaddleDriver(Driver):
             batch_size=dataloader_args.batch_size,
             drop_last=dataloader_args.drop_last
         )
-        sampler.load_state_dict(states['sampler_states'])
+        sampler.load_state_dict(states["sampler_states"])
         states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
 
         # 4. Adjust trainer_state.batch_idx_in_epoch
@@ -355,7 +355,7 @@ class PaddleDriver(Driver):
         return paddle.no_grad
 
     @staticmethod
-    def move_model_to_device(model: 'paddle.nn.Layer', device: Union[str, int, 'paddle.CUDAPlace', 'paddle.CPUPlace']):
+    def move_model_to_device(model: "paddle.nn.Layer", device: Union[str, int, "paddle.CUDAPlace", "paddle.CPUPlace"]):
         r"""
         Move the model to the given device;
         note that in Paddle this may cause problems when the device is inconsistent with the configured one.
@@ -363,7 +363,7 @@ class PaddleDriver(Driver):
         if device is not None:
             model.to(device)
 
-    def move_data_to_device(self, batch: 'paddle.Tensor'):
+    def move_data_to_device(self, batch: "paddle.Tensor"):
         r"""
         Move the data to the given machine; batch may be a list or a dict, or a nested structure of them.
         Note that in Paddle this may cause problems when the device is inconsistent with the configured one.
@@ -404,7 +404,7 @@ class PaddleDriver(Driver):
         if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None:
             dataloader.worker_init_fn = partial(self.worker_init_function, rank=self.global_rank)
 
-    def set_sampler_epoch(self, dataloader: 'DataLoader', cur_epoch_idx):
+    def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx):
         r"""
         For distributed samplers, the dataloader needs to set the random seed before each epoch,
         so that the shuffle is identical on every process;




+1 -1 fastNLP/core/drivers/torch_driver/ddp.py

@@ -406,7 +406,7 @@ class TorchDDPDriver(TorchDriver):
         if hasattr(model, fn):
             fn = getattr(model, fn)
             if not callable(fn):
-                raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
+                raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.")
             return fn, None
         elif fn in {"train_step", "evaluate_step"}:
             return model, model.forward


+1 -0 fastNLP/core/drivers/torch_driver/torch_driver.py

@@ -199,6 +199,7 @@ class TorchDriver(Driver):
                 num_consumed_batches = sampler_states['num_consumed_samples']
             sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
             assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
+            states['sampler_states'] = sampler_states
         else:
             raise RuntimeError(
                 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')


+27 -79 tests/core/controllers/test_trainer_paddle.py

@@ -1,19 +1,20 @@
 import pytest
 import os
+os.environ["FASTNLP_BACKEND"] = "paddle"
 from typing import Any
 from dataclasses import dataclass
 
-from paddle.optimizer import Adam
-from paddle.io import DataLoader
-
 from fastNLP.core.controllers.trainer import Trainer
 from fastNLP.core.metrics.accuracy import Accuracy
 from fastNLP.core.callbacks.progress_callback import RichCallback
 from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK
 
+from paddle.optimizer import Adam
+from paddle.io import DataLoader
+
-from tests.helpers.models.paddle_model import PaddleNormalModel_Classification
-from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST
+from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
+from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
 from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback
 from tests.helpers.utils import magic_argv_env_context


@@ -48,64 +49,31 @@ class TrainerParameters:
     output_mapping: Any = None
     metrics: Any = None
 
-# @pytest.fixture(params=[0], autouse=True)
-# def model_and_optimizers(request):
-#     """
-#     Initialize the model and optimizers for single-card mode
-#     """
-#     trainer_params = TrainerParameters()
-#     print(paddle.device.get_device())
-
-#     if request.param == 0:
-#         trainer_params.model = PaddleNormalModel_Classification(
-#             num_labels=MNISTTrainPaddleConfig.num_labels,
-#             feature_dimension=MNISTTrainPaddleConfig.feature_dimension
-#         )
-#         trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001)
-#         train_dataloader = DataLoader(
-#             dataset=PaddleDataset_MNIST("train"),
-#             batch_size=MNISTTrainPaddleConfig.batch_size,
-#             shuffle=True
-#         )
-#         val_dataloader = DataLoader(
-#             dataset=PaddleDataset_MNIST(evaluate_fn="test"),
-#             batch_size=MNISTTrainPaddleConfig.batch_size,
-#             shuffle=True
-#         )
-#         trainer_params.train_dataloader = train_dataloader
-#         trainer_params.evaluate_dataloaders = val_dataloader
-#         trainer_params.evaluate_every = MNISTTrainPaddleConfig.evaluate_every
-#         trainer_params.metrics = {"acc": Accuracy()}
-
-#         return trainer_params
-
 @pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)])
 # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
 @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True),
                                         RichCallback(5), RecordLossCallback(loss_threshold=0.3)]])
 @magic_argv_env_context
 def test_trainer_paddle(
-    # model_and_optimizers: TrainerParameters,
     driver,
     device,
     callbacks,
-    n_epochs=15,
+    n_epochs=2,
 ):
     trainer_params = TrainerParameters()
 
-    trainer_params.model = PaddleNormalModel_Classification(
+    trainer_params.model = PaddleNormalModel_Classification_1(
         num_labels=MNISTTrainPaddleConfig.num_labels,
         feature_dimension=MNISTTrainPaddleConfig.feature_dimension
     )
     trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001)
     train_dataloader = DataLoader(
-        dataset=PaddleDataset_MNIST("train"),
+        dataset=PaddleRandomMaxDataset(6400, 10),
         batch_size=MNISTTrainPaddleConfig.batch_size,
         shuffle=True
     )
     val_dataloader = DataLoader(
-        dataset=PaddleDataset_MNIST(mode="test"),
+        dataset=PaddleRandomMaxDataset(1000, 10),
         batch_size=MNISTTrainPaddleConfig.batch_size,
         shuffle=True
     )
@@ -113,39 +81,19 @@ def test_trainer_paddle(
     trainer_params.validate_dataloaders = val_dataloader
     trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every
     trainer_params.metrics = {"acc": Accuracy(backend="paddle")}
-    if not isinstance(device, (int, str)) and len(device) > 1 and FASTNLP_DISTRIBUTED_CHECK not in os.environ:
-        with pytest.raises(SystemExit) as exc:
-            trainer = Trainer(
-                model=trainer_params.model,
-                driver=driver,
-                device=device,
-                optimizers=trainer_params.optimizers,
-                train_dataloader=trainer_params.train_dataloader,
-                evaluate_dataloaders=trainer_params.validate_dataloaders,
-                evaluate_every=trainer_params.validate_every,
-                input_mapping=trainer_params.input_mapping,
-                output_mapping=trainer_params.output_mapping,
-                metrics=trainer_params.metrics,
-
-                n_epochs=n_epochs,
-                callbacks=callbacks,
-            )
-        assert exc.value.code == 0
-        return
-    else:
-        trainer = Trainer(
-            model=trainer_params.model,
-            driver=driver,
-            device=device,
-            optimizers=trainer_params.optimizers,
-            train_dataloader=trainer_params.train_dataloader,
-            evaluate_dataloaders=trainer_params.validate_dataloaders,
-            evaluate_every=trainer_params.validate_every,
-            input_mapping=trainer_params.input_mapping,
-            output_mapping=trainer_params.output_mapping,
-            metrics=trainer_params.metrics,
-
-            n_epochs=n_epochs,
-            callbacks=callbacks,
-        )
-        trainer.run()
+    trainer = Trainer(
+        model=trainer_params.model,
+        driver=driver,
+        device=device,
+        optimizers=trainer_params.optimizers,
+        train_dataloader=trainer_params.train_dataloader,
+        validate_dataloaders=trainer_params.validate_dataloaders,
+        validate_every=trainer_params.validate_every,
+        input_mapping=trainer_params.input_mapping,
+        output_mapping=trainer_params.output_mapping,
+        metrics=trainer_params.metrics,
+
+        n_epochs=n_epochs,
+        callbacks=callbacks,
+    )
+    trainer.run()

+45 -23 tests/core/drivers/paddle_driver/test_single_device.py

@@ -224,7 +224,6 @@ class TestSetDistReproDataloder:
""" """
def setup_method(self): def setup_method(self):
self.dataset = PaddleNormalDataset(20) self.dataset = PaddleNormalDataset(20)
self.dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
model = PaddleNormalModel_Classification_1(10, 32) model = PaddleNormalModel_Classification_1(10, 32)
self.driver = PaddleSingleDriver(model, device="cpu") self.driver = PaddleSingleDriver(model, device="cpu")
@@ -233,55 +232,59 @@ class TestSetDistReproDataloder:
         Test the behavior of set_dist_repro_dataloader when `reproducible` is False.
         When dist is a string, the original dataloader should be returned.
         """
-        replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=False)
+        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
+        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)
 
-        assert replaced_loader is self.dataloader
+        assert replaced_loader is dataloader
 
     def test_set_dist_repro_dataloader_with_reproducible_true(self):
         """
         Test the behavior of set_dist_repro_dataloader when `reproducible` is True.
         When dist is a string, a new dataloader should be returned, with a RandomBatchSampler as its batch_sampler.
         """
-        replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=True)
+        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
+        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
 
-        assert not (replaced_loader is self.dataloader)
+        assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
         assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
-        assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size
-        assert replaced_loader.drop_last == self.dataloader.drop_last
+        assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
+        assert replaced_loader.drop_last == dataloader.drop_last
 
-        # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader)
+        # self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
 
     def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
         """
         Test set_dist_repro_dataloader when dist is not a string but a ReproducibleBatchSampler.
         A new dataloader should be returned, with its batch_sampler replaced by the sampler passed as dist.
         """
+        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
         dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
-        replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False)
+        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)
 
-        assert not (replaced_loader is self.dataloader)
+        assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
         assert replaced_loader.batch_sampler is dist
 
-        # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader)
+        self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
 
     def test_set_dist_repro_dataloader_with_dist_sampler(self):
         """
         Test set_dist_repro_dataloader when dist is not a string.
         A new dataloader should be returned, with batch_sampler.sampler replaced by the sampler passed as dist.
         """
+        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
         dist = RandomSampler(self.dataset, shuffle=True)
-        replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False)
+        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)
 
-        assert not (replaced_loader is self.dataloader)
+        assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
         assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
-        assert not (replaced_loader.batch_sampler is self.dataloader.batch_sampler)
+        assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
         assert replaced_loader.batch_sampler.sampler is dist
-        assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size
+        assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
 
-        # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader)
+        self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
 
     def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self):
         """
@@ -295,11 +298,12 @@ class TestSetDistReproDataloder:
         replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)
 
         assert not (replaced_loader is dataloader)
+        assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
         assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
         assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
         assert replaced_loader.drop_last == dataloader.drop_last
 
-        # self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
+        self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
 
     def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self):
         """
@@ -316,34 +320,52 @@ class TestSetDistReproDataloder:


         assert not (replaced_loader is dataloader)
         assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
+        assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
         assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
         assert replaced_loader.batch_sampler.batch_size == 2
+        assert replaced_loader.batch_sampler.sampler.shuffle == True
 
-        # self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
+        self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
 
     def check_set_dist_repro_dataloader(self, dataloader, replaced_loader):
         """
         Check that set_dist_repro_dataloader behaves correctly on a single card.
         """
         # iterate over two batches
+        # Here the BatchSampler may have yielded several times while the dataloader fetched only once.
+        num_consumed_batches = 2
         already_seen_idx = set()
-        for idx, batch in replaced_loader:
-            already_seen_idx.update(batch)
-            if idx >= 1:
+        for idx, batch in enumerate(replaced_loader):
+            if idx >= num_consumed_batches:
                 break
+            already_seen_idx.update(batch)
         if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
             sampler_states = replaced_loader.batch_sampler.state_dict()
         else:
             sampler_states = replaced_loader.batch_sampler.sampler.state_dict()
+        print(sampler_states["data_idx"])
+
+        # load num_consumed_samples_array and set the number of batches that were actually fetched
+        num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
+
+        import time
+        time.sleep(5)
 
         # Reload; the remaining items should come out, and for PaddleNormalDataset the result,
         # once sorted, should form a range
         left_idxes = set()
         if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
+            batch_size = replaced_loader.batch_sampler.batch_size
+            if num_consumed_samples_array is not None:
+                sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
+            else:
+                sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
             replaced_loader.batch_sampler.load_state_dict(sampler_states)
         else:
+            batch_size = replaced_loader.batch_sampler.batch_size
+            if num_consumed_samples_array is not None:
+                sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
+            else:
+                sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
             replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states)
+            replaced_loader.batch_sampler.sampler.set_epoch(0)
         for idx, batch in enumerate(replaced_loader):
             left_idxes.update(batch)



+1 -0 tests/core/utils/test_utils.py

@@ -181,6 +181,7 @@ class TestCheckNumberOfParameters:




 def test_get_fun_msg():
+    # smoke test
     def demo(x):
         pass



+15 -1 tests/helpers/datasets/normal_data.py

@@ -1,3 +1,6 @@
+import numpy as np
+
+
 class NormalIterator:
     def __init__(self, num_of_data=1000):
         self._num_of_data = num_of_data
@@ -15,4 +18,15 @@ class NormalIterator:
         return self._data
 
     def __len__(self):
-        return self._num_of_data
+        return self._num_of_data
+
+
+class RandomDataset:
+    def __init__(self, num_data=10):
+        self.data = np.random.rand(num_data)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, item):
+        return self.data[item]
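
The new helper can be sanity-checked on its own; a quick sketch assuming the class above is in scope:

    ds = RandomDataset(num_data=5)
    assert len(ds) == 5
    assert all(0.0 <= ds[i] < 1.0 for i in range(len(ds)))  # np.random.rand samples from [0, 1)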
