
Merge remote-tracking branch 'refs/remotes/origin/dev0.8.0' into dev0.8.0

tags/v1.0.0alpha
MorningForest 3 years ago
parent commit 4cad7f548d
68 changed files with 2735 additions and 1089 deletions
  1. +182 -24 fastNLP/core/callbacks/callback.py
  2. +26 -36 fastNLP/core/callbacks/callback_events.py
  3. +9 -1 fastNLP/core/callbacks/callback_manager.py
  4. +8 -22 fastNLP/core/callbacks/checkpoint_callback.py
  5. +6 -5 fastNLP/core/callbacks/early_stop_callback.py
  6. +6 -5 fastNLP/core/callbacks/load_best_model_callback.py
  7. +3 -1 fastNLP/core/callbacks/progress_callback.py
  8. +14 -4 fastNLP/core/callbacks/utils.py
  9. +16 -9 fastNLP/core/collators/collator.py
  10. +23 -17 fastNLP/core/controllers/evaluator.py
  11. +6 -3 fastNLP/core/controllers/loops/train_batch_loop.py
  12. +87 -64 fastNLP/core/controllers/trainer.py
  13. +1 -1 fastNLP/core/controllers/utils/state.py
  14. +15 -5 fastNLP/core/controllers/utils/utils.py
  15. +1 -1 fastNLP/core/dataloaders/torch_dataloader/fdl.py
  16. +11 -5 fastNLP/core/dataset/dataset.py
  17. +2 -2 fastNLP/core/drivers/jittor_driver/jittor_driver.py
  18. +71 -41 fastNLP/core/drivers/paddle_driver/fleet.py
  19. +12 -13 fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py
  20. +219 -59 fastNLP/core/drivers/paddle_driver/paddle_driver.py
  21. +46 -41 fastNLP/core/drivers/paddle_driver/single_device.py
  22. +69 -38 fastNLP/core/drivers/paddle_driver/utils.py
  23. +17 -26 fastNLP/core/drivers/torch_driver/ddp.py
  24. +33 -32 fastNLP/core/drivers/torch_driver/dist_utils.py
  25. +5 -13 fastNLP/core/drivers/torch_driver/single_device.py
  26. +27 -5 fastNLP/core/drivers/torch_driver/torch_driver.py
  27. +15 -15 fastNLP/core/drivers/torch_driver/utils.py
  28. +19 -3 fastNLP/core/log/logger.py
  29. +2 -3 fastNLP/core/metrics/accuracy.py
  30. +3 -3 fastNLP/core/samplers/__init__.py
  31. +33 -0 fastNLP/core/samplers/conversion_utils.py
  32. +60 -43 fastNLP/core/samplers/reproducible_batch_sampler.py
  33. +44 -28 fastNLP/core/samplers/reproducible_sampler.py
  34. +2 -2 fastNLP/core/samplers/unrepeated_sampler.py
  35. +53 -30 fastNLP/core/samplers/utils.py
  36. +1 -2 fastNLP/core/utils/__init__.py
  37. +4 -1 fastNLP/core/utils/paddle_utils.py
  38. +77 -2 fastNLP/core/utils/rich_progress.py
  39. +4 -2 fastNLP/core/utils/torch_utils.py
  40. +117 -37 fastNLP/core/utils/utils.py
  41. +3 -2 fastNLP/envs/__init__.py
  42. +2 -0 fastNLP/envs/env.py
  43. +11 -8 fastNLP/envs/set_backend.py
  44. +1 -2 fastNLP/envs/set_env_on_import.py
  45. +13 -0 fastNLP/envs/utils.py
  46. +5 -5 fastNLP/io/data_bundle.py
  47. +1 -1 fastNLP/io/pipe/classification.py
  48. +1 -1 fastNLP/io/pipe/construct_graph.py
  49. +1 -1 fastNLP/io/pipe/pipe.py
  50. +1 -1 tests/core/callbacks/test_callback_events.py
  51. +2 -2 tests/core/callbacks/test_checkpoint_callback_torch.py
  52. +93 -0 tests/core/controllers/test_trainer_fleet.py
  53. +98 -0 tests/core/controllers/test_trainer_fleet_outside.py
  54. +25 -0 tests/core/controllers/test_trainer_other_things.py
  55. +43 -3 tests/core/controllers/test_trainer_w_evaluator_torch.py
  56. +40 -3 tests/core/controllers/test_trainer_wo_evaluator_torch.py
  57. +58 -38 tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py
  58. +0 -262 tests/core/drivers/paddle_driver/test_paddle_driver.py
  59. +526 -58 tests/core/drivers/paddle_driver/test_single_device.py
  60. +54 -2 tests/core/drivers/paddle_driver/test_utils.py
  61. +24 -24 tests/core/log/test_logger.py
  62. +85 -8 tests/core/samplers/test_reproducible_batch_sampler.py
  63. +60 -2 tests/core/samplers/test_reproducible_sampler.py
  64. +3 -3 tests/core/samplers/test_unrepeated_sampler.py
  65. +187 -0 tests/core/utils/test_utils.py
  66. +8 -2 tests/helpers/callbacks/helper_callbacks.py
  67. +27 -0 tests/helpers/models/torch_model.py
  68. +14 -17 tests/helpers/utils.py

+182 -24 fastNLP/core/callbacks/callback.py

@@ -10,6 +10,7 @@ from .utils import _get_monitor_value
from fastNLP.core.callbacks.callback_events import _SingleEventState
from fastNLP.core.log import logger
from fastNLP.core.utils import apply_to_collection
from fastNLP.core.utils.utils import _check_valid_parameters_number


class Callback:
@@ -32,100 +33,225 @@ class Callback:
def on_sanity_check_end(self, trainer, sanity_check_res):
r"""
Triggered after the 'sanity check' trial run ends;

:param trainer:
:param sanity_check_res: the evaluate results of the sanity-check run
:return:
"""
pass

def on_train_begin(self, trainer):
r"""
Triggered before training begins;

:param trainer:
:return:
"""
pass

def on_train_end(self, trainer):
r"""
Triggered after training completes;

:param trainer:
:return:
"""
pass

def on_train_epoch_begin(self, trainer):
r"""
Triggered before each epoch of training begins;

:param trainer:
:return:
"""
pass

def on_train_epoch_end(self, trainer):
r"""
Triggered after each epoch of training completes;
Triggered after each epoch of training completes; at this point trainer.cur_epoch_idx has already been incremented by 1.

:param trainer:
:return:
"""
pass

def on_fetch_data_begin(self, trainer):
r"""
Triggered before fetching the current concrete batch during training;
Triggered when the next batch of data is about to be fetched during training

:param trainer:
:return:
"""
pass

def on_fetch_data_end(self, trainer):
r"""
Triggered after the current concrete batch has been fetched during training;
Triggered after the current batch of data has been obtained during training;

:param trainer:
:return:
"""
pass

def on_train_batch_begin(self, trainer, batch, indices=None):
def on_train_batch_begin(self, trainer, batch, indices):
r"""
Triggered before a concrete batch starts during training;
Triggered once the data has been fetched, input_mapping has been applied (if the Trainer received that parameter), and the tensors in the batch have been moved to the designated device.
The data in batch is either in the format the Dataloader returns for each batch, or the result of input_mapping.
If batch is of dict type, adding or removing keys, or modifying values in place, will affect the batch data that is fed into the model.

:param trainer: `fastNLP.Trainer`
:param batch: the batch currently being run;
:param indices: the position of the current batch within the epoch, so that users can conveniently locate specific data through this callback function;
:param batch: the batch data, already passed through input_mapping (if any) and moved to the designated device.
:param list[int] indices: which samples of the dataset the current batch corresponds to
"""
pass

def on_train_batch_end(self, trainer):
"""
Called once a batch has finished training (forward), gradient backprop (backward), parameter update (step), gradient zeroing, and the increment of batch_idx_in_epoch and
global_forward_batches by 1. Note that the parameter update and gradient zeroing take accumulation_steps into account, so they are not necessarily
executed on the current batch.

:param trainer:
:return:
"""
pass

def on_exception(self, trainer, exception):
"""
Called when an exception is encountered during training.

:param trainer:
:param exception: the exception that occurred.
:return:
"""
pass

def on_save_model(self, trainer):
"""
Called when the model is about to be saved; at this point the model has not been saved yet.

:param trainer:
:return:
"""
pass

def on_load_model(self, trainer):
"""
Called when the model is about to be loaded; at this point the model has not been loaded yet.

:param trainer:
:return:
"""
pass

def on_save_checkpoint(self, trainer) -> Dict:
"""
When two callbacks are determined to be the same (identical callback_name, meaning they serve the same role), they should save in this function the state needed for the callback to work
properly; it should not be this function's job to decide whether two callbacks are the same;
Triggered when the Trainer is about to save a checkpoint; this function saves the data the current callback needs for later restoration.

:param trainer:
:return:
"""
pass

def on_load_checkpoint(self, trainer, states: Optional[Dict]):
r"""
If a callback saved no state before checkpoint-resumed training, or its `callback_name` collides with that of another callback, `states` is None;
Triggered when the Trainer is restoring a checkpoint (the Trainer and Driver have already loaded their own states); the parameter states is the return value of on_save_checkpoint().

:param trainer:
:param states:
:return:
"""
pass

def on_before_backward(self, trainer, outputs):
"""
Executed before backward.

:param trainer:
:param outputs: what the model returned. If output_mapping is set, outputs holds the result after output_mapping has been applied.
:return:
"""
pass

def on_after_backward(self, trainer):
"""
Executed after backward. In multi-card settings, because of accumulation_steps, gradient synchronization is only triggered on the backward pass that actually updates the parameters,
so with multiple cards and accumulation_steps the gradients on different cards may be inconsistent at some steps.

:param trainer:
:return:
"""
pass

def on_before_optimizer_step(self, trainer, optimizers):
def on_before_optimizers_step(self, trainer, optimizers):
"""
Called before the optimizer performs its update. This hook is not necessarily triggered on every forward pass; the actual calls depend on accumulation_steps.

:param trainer:
:param optimizers: the optimizers, as passed to the Trainer at initialization.
:return:
"""
pass

def on_after_optimizers_step(self, trainer, optimizers):
"""
Called after the optimizer performs its update. This hook is not necessarily triggered on every forward pass; the actual calls depend on accumulation_steps.

:param trainer:
:param optimizers: the optimizers, as passed to the Trainer at initialization.
:return:
"""
pass

def on_before_zero_grad(self, trainer, optimizers):
"""
Called before the model's gradients are zeroed. This hook is not necessarily triggered on every forward pass; the actual calls depend on accumulation_steps.

:param trainer:
:param optimizers: the optimizers, as passed to the Trainer at initialization.
:return:
"""
pass

def on_after_zero_grad(self, trainer, optimizers):
"""
Called after the model's gradients are zeroed. This hook is not necessarily triggered on every forward pass; the actual calls depend on accumulation_steps.

:param trainer:
:param optimizers: the optimizers, as passed to the Trainer at initialization.
:return:
"""
pass

def on_validate_begin(self, trainer):
"""
Called when validation is about to run. If the validate frequency is controlled by a step count or a custom function, this hook is called after
on_train_batch_end; if it is controlled by an epoch count, it is called after on_train_epoch_end.

:param trainer:
:return:
"""
pass

def on_validate_end(self, trainer, results):
"""
Called when validation ends, with the validation results passed in.

:param trainer:
:param results:
:return:
"""
pass

@property
def callback_name(self):
"""
The name of the callback; we use it to read the corresponding state from the checkpoint and pass it to the on_load_checkpoint() function.

:return:
"""
return self.__class__.__name__


@@ -174,7 +300,11 @@ class HasMonitorCallback(Callback):
self.must_have_moinitor = must_have_monitor

def set_monitor(self, monitor, larger_better):
self.monitor = str(monitor) if monitor is not None else None
if callable(monitor): # check that it can accept a single argument
_check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor')
self.monitor = monitor
else:
self.monitor = str(monitor) if monitor is not None else None
self.larger_better = bool(larger_better)
if larger_better:
self.monitor_value = float('-inf')
@@ -197,24 +327,33 @@ class HasMonitorCallback(Callback):
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
f"You can set it in the initialization or through Trainer.")

def get_monitor_value(self, results:Dict)->float:
def get_monitor_value(self, results:Dict)->Union[float, None]:
"""
Fetch the value of the monitor; if the monitor is not found directly, matching is attempted, and the matched key is stored on the self._real_monitor attribute.

:param results:
:return:
:return: if None, no suitable monitor was found this time
"""
if len(results)==0:
return 0
return None
# make sure every tensor has been converted to a plain Python type
results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
real_monitor=self._real_monitor,
res=results)
if self._real_monitor != use_monitor: # a substitution happened, so it is logged
logger.warning(
f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.")
if monitor_value is None:
return monitor_value
# first run
if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.")
# detected that this run differs from the previous one.
elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
f"The expected monitor is:`{self.monitor}`, last used monitor is:"
f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
f"customized monitor function when the evaluation results are varying between validation.")

self._real_monitor = use_monitor
return monitor_value

@@ -222,14 +361,33 @@ class HasMonitorCallback(Callback):
"""
Check whether the monitor_value is better

:param monitor_value:
:param monitor_value: the monitor_value to check. If None, False is returned
:param keep_if_better: if the passed-in monitor_value is better, keep it.
:return:
"""
if monitor_value is None:
return False
better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
if keep_if_better and better:
self.monitor_value = monitor_value
return better

def is_former_monitor_value_better(self, monitor_value1, monitor_value2):
"""
Of the two values passed in, whether monitor_value1 is the better result.

:param monitor_value1:
:param monitor_value2:
:return:
"""
if monitor_value1 is None and monitor_value2 is None:
return True
if monitor_value1 is None:
return False
if monitor_value2 is None:
return True
better = False
if (self.larger_better and monitor_value > self.monitor_value) or \
(not self.larger_better and monitor_value < self.monitor_value):
if (self.larger_better and monitor_value1 > monitor_value2) or \
(not self.larger_better and monitor_value1 < monitor_value2):
better = True
if keep_if_better:
self.monitor_value = monitor_value
return better
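
To illustrate the callable-monitor support added above, here is a minimal sketch; the metric keys and the averaging logic are illustrative assumptions, not part of this commit:

from fastNLP.core.callbacks import EarlyStopCallback

def my_monitor(results):
    # set_monitor() above requires the callable to accept exactly one
    # parameter (named `results` here): the evaluation results dict
    return (results.get('f1#dev', 0.0) + results.get('acc#dev', 0.0)) / 2

# larger_better applies to the returned float just as it would to a string monitor
callback = EarlyStopCallback(monitor=my_monitor, larger_better=True, patience=5)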

+26 -36 fastNLP/core/callbacks/callback_events.py

@@ -74,28 +74,30 @@ class EventEnum(_SingleEventState, Enum):

@unique
class Events(EventEnum):
ON_AFTER_TRAINER_INITIALIZED = "on_after_trainer_initialized"
ON_SANITY_CHECK_BEGIN = "on_sanity_check_begin"
ON_SANITY_CHECK_END = "on_sanity_check_end"
ON_TRAIN_BEGIN = "on_train_begin"
ON_TRAIN_END = "on_train_end"
ON_TRAIN_EPOCH_BEGIN = "on_train_epoch_begin"
ON_TRAIN_EPOCH_END = "on_train_epoch_end"
ON_FETCH_DATA_BEGIN = "on_fetch_data_begin"
ON_FETCH_DATA_END = "on_fetch_data_end"
ON_TRAIN_BATCH_BEGIN = "on_train_batch_begin"
ON_TRAIN_BATCH_END = "on_train_batch_end"
ON_EXCEPTION = "on_exception"
ON_SAVE_MODEL = "on_save_model"
ON_LOAD_MODEL = "on_load_model"
ON_SAVE_CHECKPOINT = "on_save_checkpoint"
ON_LOAD_CHECKPOINT = "on_load_checkpoint"
ON_BEFORE_BACKWARD = "on_before_backward"
ON_AFTER_BACKWARD = "on_after_backward"
ON_BEFORE_OPTIMIZER_STEP = "on_before_optimizer_step"
ON_BEFORE_ZERO_GRAD = "on_before_zero_grad"
ON_VALIDATE_BEGIN = "on_validate_begin"
ON_VALIDATE_END = "on_validate_end"
on_after_trainer_initialized = "on_after_trainer_initialized"
on_sanity_check_begin = "on_sanity_check_begin"
on_sanity_check_end = "on_sanity_check_end"
on_train_begin = "on_train_begin"
on_train_end = "on_train_end"
on_train_epoch_begin = "on_train_epoch_begin"
on_train_epoch_end = "on_train_epoch_end"
on_fetch_data_begin = "on_fetch_data_begin"
on_fetch_data_end = "on_fetch_data_end"
on_train_batch_begin = "on_train_batch_begin"
on_train_batch_end = "on_train_batch_end"
on_exception = "on_exception"
on_save_model = "on_save_model"
on_load_model = "on_load_model"
on_save_checkpoint = "on_save_checkpoint"
on_load_checkpoint = "on_load_checkpoint"
on_before_backward = "on_before_backward"
on_after_backward = "on_after_backward"
on_before_optimizers_step = "on_before_optimizers_step"
on_after_optimizers_step = "on_after_optimizers_step"
on_before_zero_grad = "on_before_zero_grad"
on_after_zero_grad = "on_after_zero_grad"
on_validate_begin = "on_validate_begin"
on_validate_end = "on_validate_end"


class EventsList:
@@ -169,20 +171,8 @@ class Filter:
self.num_called += 1

# Because the inputs of our callback functions are fixed, and we can guarantee that the first argument is always the trainer,
# we can extract the trainer from the callback function's inputs and feed it into our filter, which enables some more complex logic;
# meanwhile, when the first input argument of the function decorated by Filter is not a trainer, we pass only self into the _filter function;

# argument-extraction logic;
trainer = kwargs.get("trainer", None)

if trainer is None and len(args) > 0:
trainer = args[0]
if isinstance(trainer, fastNLP.Trainer): # because of the circular-import problem we cannot use fastNLP.Trainer directly here, since Trainer
# also imports this module, while Controller does not;
param = (self, trainer)
else:
param = (self, )
if self._filter(*param):
trainer = args[0]
if self._filter(self, trainer):
self.num_executed += 1
return fn(*args, **kwargs)



+9 -1 fastNLP/core/callbacks/callback_manager.py

@@ -278,13 +278,21 @@ class CallbackManager:
pass

@_transfer
def on_before_optimizer_step(self, trainer, optimizers):
def on_before_optimizers_step(self, trainer, optimizers):
pass

@_transfer
def on_after_optimizers_step(self, trainer, optimizers):
pass

@_transfer
def on_before_zero_grad(self, trainer, optimizers):
pass

@_transfer
def on_after_zero_grad(self, trainer, optimizers):
pass

@_transfer
def on_validate_begin(self, trainer):
pass


+8 -22 fastNLP/core/callbacks/checkpoint_callback.py

@@ -10,12 +10,10 @@ from copy import deepcopy


import fastNLP
from .callback import Callback, HasMonitorCallback
from fastNLP.core.callbacks.utils import _get_monitor_value
from .callback import HasMonitorCallback
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_LAUNCH_TIME
from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir
from fastNLP.core.utils import apply_to_collection


class CheckpointCallback(HasMonitorCallback):
@@ -167,6 +165,8 @@ class CheckpointCallback(HasMonitorCallback):
"""
if self.save_topk is not None:
monitor_value = self.get_monitor_value(results=results)
if monitor_value is None:
return
folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \
f"-{self._real_monitor}_{monitor_value}"

@@ -178,8 +178,7 @@ class CheckpointCallback(HasMonitorCallback):
else:
_least_valuable_model = (min if self.larger_better else max)(self._topk_model,
key=lambda x: self._topk_model[x])
if (self.larger_better and monitor_value > self._topk_model[_least_valuable_model]) or \
(self.larger_better is False and monitor_value < self._topk_model[_least_valuable_model]):
if self.is_former_monitor_value_better(monitor_value, self._topk_model[_least_valuable_model]):
self._topk_model[folder_name] = monitor_value
_should_save = True
self._topk_model.pop(_least_valuable_model)
@@ -208,21 +207,6 @@ class CheckpointCallback(HasMonitorCallback):
**self.kwargs
)

def _get_validate_metric(self, res: Dict):
"""
This function finds, in the `Evaluator` results, the metric result that belongs to the current CheckpointCallback (according to monitor);
if the user's input is not found in res, we look through all keys of the validate result dict and, matching by longest common substring, use the value of the longest match;
:param res:
:return:
"""
use_monitor, value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=res)
if self._real_monitor != use_monitor:
logger.warning(f"We can not find `{self._real_monitor}` in the evaluation result (with keys as {list(res.keys())}), "
f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.")
self._real_monitor = use_monitor

return value

@property
def folder_prefix(self):
raise NotImplementedError("The `folder_prefix` is not specified")
@@ -248,7 +232,8 @@ class ModelCheckpointCallback(CheckpointCallback):
If model_save_fn is not None, fastNLP passes the absolute path of folder to that function and creates no files under that folder itself.

:param monitor: the name of the metric to monitor. If no exactly matching name is found in the evaluation results, the longest-common-substring algorithm is used to find the closest
match as the monitor. If None, the value will be taken from the Trainer.
match as the monitor. If None, the value will be taken from the Trainer. A function may also be passed in, taking the evaluation results (a dict)
and returning a float as the monitor value.
:param save_folder: the folder to save into; fastNLP creates a timestamped subfolder under it and saves there, so different runs can be saved into
different timestamped folders. If None, the current folder is used by default.
:param save_every_n_epochs: save once every this many epochs.
@@ -295,7 +280,8 @@ class TrainerCheckpointCallback(CheckpointCallback):
If model_save_fn is not None, fastNLP will only generate the fastnlp_trainer.pkl.tar file under each folder.

:param monitor: the name of the metric to monitor. If no exactly matching name is found in the evaluation results, the longest-common-substring algorithm is used to find the closest
match as the monitor. If None, the value will be taken from the Trainer.
match as the monitor. If None, the value will be taken from the Trainer. A function may also be passed in, taking the evaluation results (a dict)
and returning a float as the monitor value.
:param save_folder: the folder to save into; fastNLP creates a timestamped subfolder under it and saves there, so different runs can be saved into
different timestamped folders. If None, the current folder is used by default.
:param save_every_n_epochs: save once every this many epochs.


+6 -5 fastNLP/core/callbacks/early_stop_callback.py

@@ -2,17 +2,18 @@ __all__ = [
'EarlyStopCallback'
]

from typing import Dict
from typing import Dict, Union, Callable

from .callback import HasMonitorCallback
from fastNLP.core.utils.exceptions import EarlyStopException


class EarlyStopCallback(HasMonitorCallback):
def __init__(self, monitor:str=None, larger_better:bool=True, patience:int=10):
def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10):
"""

:param str monitor: the metric value to monitor. If None, the monitor set in the Trainer will be used.
:param str monitor: the metric value to monitor. If None, the monitor set in the Trainer will be used. A function may also be passed in, taking
the evaluation results (a dict) and returning a float as the monitor value.
:param larger_better: whether a larger monitor value is better.
:param patience: stop after this many validations without improvement.
"""
@@ -21,9 +22,9 @@ class EarlyStopCallback(HasMonitorCallback):
self.patience = patience

def on_validate_end(self, trainer, results):
if len(results)==0:
return
monitor_value = self.get_monitor_value(results)
if monitor_value is None:
return
if self.is_better_monitor_value(monitor_value, keep_if_better=True):
self.wait = 0
else:


+6 -5 fastNLP/core/callbacks/load_best_model_callback.py

@@ -3,7 +3,7 @@ __all__ = [
]

import os
from typing import Optional, Callable
from typing import Optional, Callable, Union
from .callback import HasMonitorCallback
from io import BytesIO
import shutil
@@ -14,14 +14,15 @@ from fastNLP.envs import all_rank_call


class LoadBestModelCallback(HasMonitorCallback):
def __init__(self, monitor:str=None, larger_better:bool = True, only_state_dict:bool = True,
def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool = True, only_state_dict:bool = True,
save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None,
model_load_fn:Optional[Callable] = None,
delete_after_train:bool = True):
"""
Saves the model with the best monitor value, and reloads that model when training ends. The best model can only be loaded if training finishes normally.

:param str monitor: the metric value to monitor. If None, the monitor set in the Trainer will be used.
:param str monitor: the metric value to monitor. If None, the monitor set in the Trainer will be used. A function may also be passed in, taking
the evaluation results (a dict) and returning a float as the monitor value.
:param larger_better: whether a larger metric value is better.
:param save_folder: the folder to save into; if empty, the model is kept in memory. If not empty, a copy of the weights is saved to file; for multi-machine training with this value set,
make sure the path is accessible from every machine. This value must not be empty when model_save_fn is not None.
@@ -78,9 +79,9 @@ class LoadBestModelCallback(HasMonitorCallback):
self.get_monitor_value(sanity_check_res)

def on_validate_end(self, trainer, results):
if len(results)==0:
return
monitor_value = self.get_monitor_value(results)
if monitor_value is None:
return
if self.is_better_monitor_value(monitor_value, keep_if_better=True):
if self.real_save_folder:
trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,


+3 -1 fastNLP/core/callbacks/progress_callback.py

@@ -45,6 +45,7 @@ class RichCallback(ProgressCallback):
:param print_every: update the display every this many batches.
:param loss_round_ndigit: how many significant digits of the loss to display
:param monitor: when the result under this key improves, it is printed in a different color as a hint. If None, the monitor set in the trainer will be used.
A function may also be passed in, taking the evaluation results (a dict) and returning a float as the monitor value.
:param larger_better: whether a larger monitor result is better.
:param format_json: whether to format the json before printing
"""
@@ -135,7 +136,8 @@ class RawTextCallback(ProgressCallback):

:param print_every: update the display every this many batches.
:param loss_round_ndigit: how many significant digits of the loss to display
:param monitor: when the result under this key improves, it is printed in a different color as a hint.
:param monitor: when the result under this key improves, it is printed in a different color as a hint. A function may also be passed in, taking the evaluation results (
a dict) and returning a float as the monitor value.
:param larger_better: whether a larger monitor result is better.
:param format_json: whether to format the json before printing
"""


+14 -4 fastNLP/core/callbacks/utils.py

@@ -1,9 +1,10 @@
from typing import Optional
from typing import Optional, Union
from fastNLP.core.log.logger import logger
from difflib import SequenceMatcher
from fastNLP.core.utils.utils import _get_fun_msg


def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->(str, float):
def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str], res: dict) ->(str, float):
"""
Look up monitor in res and return it. If monitor is not found, try _real_monitor; if _real_monitor is None, try to match
using the value of monitor.
@@ -11,10 +12,19 @@ def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->(
:param monitor:
:param real_monitor:
:param res:
:return: returns two values (str, value), where str is the key finally found and value is the value for that key
:return: returns two values (str, value), where str is the key finally found and value is the value for that key. If value is None, no matching
monitor was found in the current results
"""
if len(res)==0:
return monitor, 0
return monitor, None

if callable(monitor):
try:
monitor_value = monitor(res)
except BaseException as e:
logger.error(f"Exception happens when calling customized monitor function:{_get_fun_msg(monitor)}.")
raise e
return monitor, monitor_value

if monitor in res:
return monitor, res[monitor]


+16 -9 fastNLP/core/collators/collator.py

@@ -5,7 +5,7 @@ __all__ = [


from abc import ABCMeta, abstractmethod
from typing import Any, Dict, List, Callable, Union
from typing import Any, Dict, List, Callable, Union, Tuple
from numbers import Number
import warnings

@@ -35,7 +35,7 @@ class SetInputOrTargetException(Exception):
self.field_name = field_name # marks the name of the current field


def _get_ele_type_and_dim(cell: Any, dim=0):
def _get_ele_type_and_dim(cell: Any, dim=0) -> Tuple[Any, int]:
r"""
Identify the type of the cell and its number of dimensions

@@ -197,7 +197,7 @@ class _MultiCollator:
collator.set_input(*field_names)
flag = False
if flag:
warnings.warn("AutoCollator is remove, set_input is unavailable!!")
warnings.warn("AutoCollator is removed, set_input is unavailable!!")
return self


@@ -206,7 +206,7 @@ class AutoCollator(Collator):
def __init__(self, as_numpy: bool):
super(AutoCollator, self).__init__()
self.pad_field_value = {} # custom per-field padding values; default is 0
self.need_inputs = [] # the field names that are needed
self.need_inputs = set() # the field names that are needed
self.field_dtypes = None # the dtype of each column's data cells
self.field_dims = None # the dimensionality of each column's data cells
self.as_numpy = as_numpy
@@ -214,10 +214,17 @@ class AutoCollator(Collator):
def __call__(self, ins_lst: List[Dict]) -> dict:
if len(self.need_inputs) == 0:
raise ValueError({"set_inputs is None, you should use set_inputs method first!!"})
# TODO should first check which fields need padding, then check whether those can actually be padded

# case 1: values were set via set_input
# case 2: decide whether to pad based on the data type
if self.field_dtypes is None and self.field_dims is None:
self.field_dtypes, self.field_dims = _get_ds_type_dim(ins_lst[0])
field_dtypes, field_dims = {}, {}
for key, value in ins_lst[0].items():
if key in self.need_inputs and self.pad_field_value.get(key, 0) is not None:
field_dtypes[key], field_dims[key] = _get_ele_type_and_dim(value)
self.field_dtypes = field_dtypes
self.field_dims = field_dims

pack_ins_lst, pad_ins_lst = {field_name: []
for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {}
@@ -233,13 +240,13 @@ class AutoCollator(Collator):

if len(self.pad_field_value.keys()) > 0:
# drop the columns that do not need padding; ignore set_input columns that do not exist
drop_field_names = []
non_pad_field_names = []
for k, v in self.pad_field_value.items():
if v is None:
drop_field_names.append(k)
non_pad_field_names.append(k)

# drop_field_names = list(set(list(ins_lst[0].keys())) - set(drop_fields))
for field_name in drop_field_names:
for field_name in non_pad_field_names:
field_array = pack_ins_lst.pop(field_name)
pad_ins_lst[field_name] = np.array(field_array)

@@ -269,7 +276,7 @@ class AutoCollator(Collator):

def set_input(self, *field_names):
for field_name in field_names:
self.need_inputs.append(field_name)
self.need_inputs.add(field_name)


def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool):


+23 -17 fastNLP/core/controllers/evaluator.py

@@ -11,11 +11,12 @@ __all__ = [
from fastNLP.core.drivers import Driver
from fastNLP.core.drivers.utils import choose_driver
from .loops import Loop, EvaluateBatchLoop
from fastNLP.core.utils import check_fn_not_empty_params, auto_param_call, dataclass_to_dict, \
from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \
match_and_substitute_params, f_rich_progress
from fastNLP.core.metrics import Metric
from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader
from fastNLP.core.utils.utils import _check_valid_parameters_number
from fastNLP.core.log import logger


@@ -38,10 +39,11 @@ class Evaluator:
driver: Union[str, Driver] = 'single',
device: Optional[Union[int, List[int], str]] = None,
batch_step_fn: Optional[callable] = None,
mode: str = "validate",
mode: Optional[Union[str, callable]] = 'validate', # first try to find evaluate_step; if not found, forward; may also be a callable
input_mapping: Optional[Union[Callable, Dict]] = None,
output_mapping: Optional[Union[Callable, Dict]] = None,
fp16: Optional[bool] = False,
model_wo_auto_param_call: bool = False,
fp16: bool = False,
verbose: int = 1,
**kwargs
):
@@ -61,6 +63,9 @@ class Evaluator:
if absent, the "validate_step" function is tried; if neither is found, the model's forward function is used.
:param input_mapping: the content output by the dataloader is processed by input_mapping before being fed into the model and the metrics
:param output_mapping: the model's output is processed by output_mapping before being fed into the metrics.
:param model_wo_auto_param_call: whether to disable calling our auto_param_call during training to automatically match the batch against the forward function's parameters;
if this value is True, we pass the batch through to the forward function unchanged; if this value is False and the batch is a dict, we extract the objects
the forward function needs from the batch and pass them in. Note the same logic applies to `train_step`, `validate_step` and `test_step`;
:param fp16: whether to use fp16.
:param verbose: whether to print the evaluate results.
:param kwargs:
@@ -83,13 +88,13 @@ class Evaluator:
self.model = model
self.metrics = metrics

self.driver = choose_driver(model, driver, device, fp16=fp16, **kwargs)
self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, **kwargs)

self.device = device
self.verbose = verbose

assert check_fn_not_empty_params(batch_step_fn, 2), "Parameter `batch_step_fn` should be a callable object with " \
"two parameters."
if batch_step_fn is not None:
_check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn')
self.batch_step_fn = batch_step_fn

self.mode = mode
@@ -131,6 +136,7 @@ class Evaluator:
if self.progress_bar == 'auto':
self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw'

self.driver.check_evaluator_mode(self.mode)
self.driver.barrier()

def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict:
@@ -150,8 +156,6 @@ class Evaluator:
assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type."
assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0."

self.driver.check_evaluator_mode(self.mode)

if self.mode == 'validate':
assert self.driver.has_validate_dataloaders()
else:
@@ -219,7 +223,6 @@ class Evaluator:
def remove_progress_bar(self, dataloader_name):
if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'):
f_rich_progress.destroy_task(self._rich_task_id)
f_rich_progress.refresh() # so that the final bar can disappear
delattr(self, '_rich_task_id')
elif self.progress_bar == 'raw':
desc = 'Evaluation ends'
@@ -230,7 +233,6 @@ class Evaluator:
def finally_progress_bar(self):
if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'):
f_rich_progress.destroy_task(self._rich_task_id)
f_rich_progress.refresh()
delattr(self, '_rich_task_id')

@property
@@ -355,20 +357,24 @@ class _MetricsWrapper:
if is_dataclass(outputs):
outputs = dataclass_to_dict(outputs)
for metric in self._metrics:
args = []
if not isinstance(batch, dict):
raise RuntimeError(f"When the output of the DataLoader is of type:`{type(batch)}`, please either directly"
f" return a dict from your DataLoader or use `input_mapping` to convert it into dict type.")
logger.warning_once(f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on "
f"the output of model to update metric.")
else:
args.append(batch)
if not isinstance(outputs, dict):
raise RuntimeError(f"When the output of your model is of type:`{type(batch)}`, please either directly"
raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly"
f" return a dict from your model or use `output_mapping` to convert it into dict type.")
if isinstance(metric, Metric):
auto_param_call(metric.update, batch, outputs)
# this way, errors raised from auto_param_call are clearer.
auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__)
elif _is_torchmetrics_metric(metric):
auto_param_call(metric.update, batch, outputs)
auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__)
elif _is_allennlp_metric(metric):
auto_param_call(metric.__call__, batch, outputs)
auto_param_call(metric.__call__, outputs, *args)
elif _is_paddle_metric(metric):
res = auto_param_call(metric.compute, batch, outputs)
res = auto_param_call(metric.compute, outputs, *args)
metric.update(res)

def reset(self):

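To make the model_wo_auto_param_call switch concrete, here is a minimal sketch of the two dispatch styles; the model classes and field names are illustrative assumptions, not part of this commit:

import torch

class AutoCallModel(torch.nn.Module):
    # with model_wo_auto_param_call=False (the default), auto_param_call
    # unpacks a dict batch against this signature
    def train_step(self, input_ids, labels):
        ...

class PassThroughModel(torch.nn.Module):
    # with model_wo_auto_param_call=True, the whole batch arrives unchanged
    def train_step(self, batch):
        input_ids, labels = batch['input_ids'], batch['labels']
        ...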

+6 -3 fastNLP/core/controllers/loops/train_batch_loop.py

@@ -7,6 +7,7 @@ from typing import Optional, Callable
from .loop import Loop
from fastNLP.core.log import logger
from fastNLP.core.utils import match_and_substitute_params
from fastNLP.core.utils.exceptions import EarlyStopException


class TrainBatchLoop(Loop):
@@ -23,13 +24,15 @@ class TrainBatchLoop(Loop):
try:
trainer.on_fetch_data_begin()
batch = next(dataloader)
batch = match_and_substitute_params(trainer.input_mapping, batch)
indices = get_batch_indices()
batch = trainer.move_data_to_device(batch)
trainer.on_fetch_data_end()
batch = match_and_substitute_params(trainer.input_mapping, batch)
batch = trainer.move_data_to_device(batch)
except StopIteration:
break
except BaseException as e: # TODO write the relevant information in here
except EarlyStopException: # the earlystop exception is handled in the Trainer
break
except BaseException as e:
if indices:
logger.debug(f"The following exception happens when running on samples: {indices}")
raise e

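With the new except-clause above, the loop exits cleanly when an EarlyStopException propagates out of a callback; a minimal sketch of a callback that requests early stopping this way (the batch budget and the message argument are illustrative assumptions):

from fastNLP.core.callbacks import Callback
from fastNLP.core.utils.exceptions import EarlyStopException

class StopAfterNBatches(Callback):
    # hypothetical callback: end training once a fixed batch budget is spent
    def on_train_batch_end(self, trainer):
        if trainer.global_forward_batches >= 1000:
            raise EarlyStopException('batch budget exhausted')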

+87 -64 fastNLP/core/controllers/trainer.py

@@ -14,6 +14,7 @@ __all__ = [

from .loops import Loop, TrainBatchLoop
from .utils import State, TrainerState
from .utils.utils import check_validate_every
from .evaluator import Evaluator
from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList, Filter
@@ -21,7 +22,8 @@ from fastNLP.core.callbacks.callback import _CallbackWrapper
from fastNLP.core.callbacks.callback_events import _SingleEventState
from fastNLP.core.drivers import Driver
from fastNLP.core.drivers.utils import choose_driver
from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext
from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext
from fastNLP.core.utils.utils import _check_valid_parameters_number
from fastNLP.envs import rank_zero_call
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_MODEL_FILENAME
@@ -42,15 +44,16 @@ class Trainer(TrainerEventTrigger):
validate_dataloaders=None,
batch_step_fn: Optional[Callable] = None,
validate_batch_step_fn: Optional[Callable] = None,
validate_mode: str = "validate",
validate_mode: Union[str, callable] = 'validate',
callbacks: Union[List[Callback], Callback, None] = None,
metrics: Optional[dict] = None,
validate_every: Optional[Union[int, callable]] = -1,
input_mapping: Optional[Union[Callable, Dict]] = None,
output_mapping: Optional[Union[Callable, Dict]] = None,
model_wo_auto_param_call: bool = False,
accumulation_steps: int = 1,
fp16: bool = False,
monitor: str = None,
monitor: Union[str, callable] = None,
larger_better: bool = True,
marker: Optional[str] = None,
**kwargs
@@ -89,11 +92,8 @@ class Trainer(TrainerEventTrigger):
:param callbacks: callback classes triggered during training; this parameter should be a list in which every element inherits from the `Callback` class;
:param metrics: should be a dict whose keys are the monitors, e.g. {"acc1": AccMetric(), "acc2": AccMetric()};
:param validate_every: can be a negative number, a positive number, or a function; negative means validate once every that many epochs; positive means validate once every that many batches;
when a function, it is a user-supplied function for controlling how often the Trainer validates; its parameters should be (filter, trainer), where the filter object
automatically records two variables: filter.num_called, how many validate attempts there have been (effectively the total number of batches so far), and filter.num_executed,
how many times validate has actually been executed; the trainer parameter is the Trainer object. The function should return a bool; True means validation should run.
For example: (filter.num_called % trainer.num_batches_per_epoch == 0 and trainer.cur_epoch_idx > 10) means validate at the end of every epoch
after the 10th epoch.
when a function, it is a user-supplied function for controlling how often the Trainer validates; it should take the current trainer object as its argument
and return a bool, True meaning validation should run; the function is called after every batch to decide whether to validate.
:param input_mapping: should be a dict or a function describing how a batch of training data, once fetched at the current step, should be mapped; if it is
a dict and the batch is also a `Dict`, we change the keys of the batch that also appear in input_mapping to the value of the corresponding input_mapping
key; if the batch is a `dataclass`, we first convert the dataclass to a Dict and then apply the transformation above; if the batch is of some other
@@ -102,12 +102,15 @@ class Trainer(TrainerEventTrigger):
:param output_mapping: should be a dict or a function. Works like input_mapping, except it is used to convert the output; if output_mapping is a
function, the model's output is passed to it directly; if it is a `Dict`, the batch must be of `Dict` or `dataclass` type;
if the batch is a `Dict`, we change the keys of the batch that also appear in output_mapping to the value of the corresponding output_mapping key;
if the batch is a `dataclass`, we first convert the dataclass to a Dict and then apply the transformation above
if the batch is a `dataclass`, we first convert the dataclass to a Dict and then apply the transformation above;
:param model_wo_auto_param_call: whether to disable calling our auto_param_call during training to automatically match the batch against the forward function's parameters;
if this value is False and the batch is a dict, we extract the objects the forward function needs from the batch and pass them into the forward function; if this value
is True, we pass the batch through to the model unchanged. Note this parameter applies to `train_step`, `validate_step` and `test_step`;
:param accumulation_steps: the number of gradient-accumulation steps, i.e. the optimizer steps once every that many batches; default 1;
:param fp16: whether to enable mixed-precision training; default False;
:param monitor: the default name of the monitor metric when validate_dataloaders exists. Callbacks that take a monitor parameter and did not set it at
initialization will adopt this value. If no exactly matching name is found in the evaluation results, the longest-common-substring algorithm is used to find the closest
match as the monitor.
match as the monitor. A function may also be passed in, taking the evaluation results (a dict) and returning a float as the monitor value.
:param larger_better: whether a larger monitor value is better.
:param marker: marks a Trainer instance so that, when the user calls the `Trainer.on` function, the callback can be tied to a specific 'trainer' instance; default None;
:param kwargs: other parameters that may be needed;
@@ -126,20 +129,21 @@ class Trainer(TrainerEventTrigger):
auto means: if the current terminal is detected to be interactive, use rich; otherwise use raw.

"""

# TODO maybe add a parameter that lets the user turn off parameter matching for now.
self.marker = marker
self.model = model
self.driver_name = driver
self.marker = marker
if isinstance(driver, str):
self.driver_name = driver
else:
self.driver_name = driver.__class__.__name__
self.device = device
self.optimizers = optimizers
self.fp16 = fp16
self.input_mapping = input_mapping
self.output_mapping = output_mapping

assert check_fn_not_empty_params(batch_step_fn, 2), "`batch_step_fn` should be a callable object with " \
"two parameters."
self.batch_step_fn = batch_step_fn
if batch_step_fn is not None:
_check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn')
self.check_batch_step_fn = partial(self._check_callback_called_legality, check_mode=True)
else:
self.check_batch_step_fn = lambda *args, **kwargs: ...
@@ -155,6 +159,8 @@ class Trainer(TrainerEventTrigger):
elif accumulation_steps < 0:
raise ValueError("Parameter `accumulation_steps` can only be bigger than 0.")
self.accumulation_steps = accumulation_steps

# todo the rough idea: each driver exposes what its own parameters are (mapped back to those used at initialization), and then the trainer/evaluator checks at initialization whether the parameters it holds agree with the driver's, warning the user wherever the values differ from the driver's. This can probably wait until later.
self.driver = choose_driver(
model=model,
driver=driver,
@@ -171,6 +177,7 @@ class Trainer(TrainerEventTrigger):
validate_every=validate_every,
input_mapping=input_mapping,
output_mapping=output_mapping,
model_wo_auto_param_call=model_wo_auto_param_call,
accumulation_steps=accumulation_steps,
fp16=fp16,
marker=marker,
@@ -212,17 +219,11 @@ class Trainer(TrainerEventTrigger):
if metrics is not None and validate_dataloaders is None:
raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.")

# To check on every iteration of the train loop whether validation is needed, we determine here, ahead of time at trainer initialization, the functions to run at the corresponding moments;
# _epoch_validate means validate once every few epochs; _step_validate means validate once every few steps;
self.evaluator = None
self.epoch_validate = lambda *args, **kwargs: ...
self.step_validate = lambda *args, **kwargs: ...
self.monitor = monitor
self.larger_better = larger_better
if metrics is not None and validate_dataloaders is not None:
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0):
raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.")

check_validate_every(validate_every)
self.evaluator = Evaluator(
model=model,
dataloaders=validate_dataloaders,
@@ -239,16 +240,6 @@ class Trainer(TrainerEventTrigger):
progress_bar=kwargs.get('progress_bar', 'auto')
)

if callable(validate_every):
self._step_validate_filter = Filter(filter_fn=validate_every)
logger.info("Notice you are using a 'filter function' as the value of parameter `validate_every`, "
"and in this way, the kind of controlling frequency is depending on the 'step'.")
elif validate_every < 0:
self._epoch_validate_filter = Filter(every=-validate_every)
else:
# validate_every > 0
self._step_validate_filter = Filter(every=validate_every)

self.metrics = metrics
self.validate_every = validate_every

@@ -317,6 +308,8 @@ class Trainer(TrainerEventTrigger):

try:
while self.cur_epoch_idx < self.n_epochs:
# this guards against saving again after Trainer.load before the current epoch has finished
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch
self.driver.set_model_mode("train")
self.on_train_epoch_begin()
self.driver.set_sampler_epoch(self.dataloader, self.cur_epoch_idx)
@@ -345,31 +338,37 @@ class Trainer(TrainerEventTrigger):
raise e

def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl):
def _validate_fn(validate_fn: Callable, trainer: Trainer) -> None:
def _validate_fn(trainer: Trainer, validate_fn: Callable) -> None:
trainer.on_validate_begin()
_validate_res: dict = validate_fn()
trainer.on_validate_end(_validate_res)

self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl))

def step_validate(self):
"""
Called after every batch; runs evaluate according to the settings.

:return:
"""
if self.evaluator is not None:
if callable(self.validate_every):
self.step_validate = self._step_validate_filter(partial(
_validate_fn,
partial(self.evaluator.run, num_eval_batch_per_dl),
self
))
elif self.validate_every < 0:
self.epoch_validate = self._epoch_validate_filter(partial(
_validate_fn,
partial(self.evaluator.run, num_eval_batch_per_dl),
self
))
else:
# validate_every > 0
self.step_validate = self._step_validate_filter(partial(
_validate_fn,
partial(self.evaluator.run, num_eval_batch_per_dl),
self
))
if self.validate_every(self):
self.run_evaluate()
elif self.validate_every > 0 and self.global_forward_batches % self.validate_every == 0:
self.run_evaluate()

def epoch_validate(self):
"""
Called after every epoch; runs evaluate according to the settings.

:return:
"""
if self.evaluator is not None:
if isinstance(self.validate_every, int) and self.validate_every < 0:
validate_every = -self.validate_every
if self.cur_epoch_idx % validate_every == 0:
self.run_evaluate()

def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable):
r"""
@@ -400,9 +399,8 @@ class Trainer(TrainerEventTrigger):

def wrapper(fn: Callable) -> Callable:
cls._custom_callbacks[marker].append((event, fn))
assert check_fn_not_empty_params(fn, len(get_fn_arg_names(getattr(Callback, event.value))) - 1), "Your " \
"callback fn's allowed parameters seem not to be equal with the origin callback fn in class " \
"`Callback` with the same callback time."
callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:]
_check_valid_parameters_number(fn, callback_fn_args)
return fn

return wrapper
@@ -431,9 +429,11 @@ class Trainer(TrainerEventTrigger):

2. Purpose of this function
This function checks whether a user-customized batch_step_fn / TrainBatchLoop can correctly invoke the callback functions; more precisely, when the user has actually
customized ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") /
customized ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad",
"on_after_zero_grad") /
("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end",
"on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad")
"on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad",
"on_after_zero_grad")
these callback_fns, and has also customized batch_step_fn / TrainBatchLoop, they may have forgotten to invoke these callback
functions in their own batch_step_fn; this function's job is to detect whether the user has done that;

@@ -443,10 +443,12 @@ class Trainer(TrainerEventTrigger):
'batch_step_fn'; when False, it means checking 'TrainBatchLoop';
"""
if check_mode:
callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad")
callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
"on_before_zero_grad", "on_after_zero_grad")
else:
callbacks = ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end",
"on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad")
"on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
"on_before_zero_grad", "on_after_zero_grad")
_not_called_callback_fns = []
for each_callback_fn in callbacks:
if each_callback_fn in self.callback_manager.callback_fns:
@@ -498,8 +500,6 @@ class Trainer(TrainerEventTrigger):

@driver.setter
def driver(self, driver: Driver):
driver.trainer = self
driver.model = self.model
self._driver = driver

@property
@@ -591,7 +591,9 @@ class Trainer(TrainerEventTrigger):
# 1. callback states and the filter state of each callback's concrete callback functions;
# 2. trainer_state;
states = {"callback_states": self.on_save_checkpoint(),
"trainer_state": self.trainer_state.state_dict()}
"trainer_state": self.trainer_state.state_dict(),
'num_consumed_batches': self.batch_idx_in_epoch - getattr(self, 'start_batch_idx_in_epoch', 0)
}

# 3. validate filter state;
if self.evaluator is not None:
@@ -668,6 +670,10 @@ class Trainer(TrainerEventTrigger):
# The principle here is that 'the number of batches still to be produced' + 'batch_idx_in_epoch' should equal 'the total number of batches had training not been interrupted'. Since
# 'the number of batches still to be produced' is determined by how many samples remain, the equation can only be satisfied by adjusting 'batch_idx_in_epoch'
self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch')
self.trainer_state.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + \
self.batch_idx_in_epoch
# this guards against the user saving again after Trainer.load before the current epoch has finished
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch

# 5. restore the state of all callbacks;
self.on_load_checkpoint(states["callback_states"])
@@ -692,13 +698,15 @@ class Trainer(TrainerEventTrigger):

def zero_grad(self):
if (self.global_forward_batches + 1) % self.accumulation_steps == 0:
self.on_before_zero_grad(self.driver.optimizers)
self.on_before_zero_grad(self.optimizers)
self.driver.zero_grad(self.set_grad_to_none)
self.on_after_zero_grad(self.optimizers)

def step(self):
if (self.global_forward_batches + 1) % self.accumulation_steps == 0:
self.on_before_optimizer_step(self.driver.optimizers)
self.on_before_optimizers_step(self.optimizers)
self.driver.step()
self.on_after_optimizers_step(self.optimizers)

def move_data_to_device(self, batch):
return self.driver.move_data_to_device(batch)
@@ -796,4 +804,19 @@ class Trainer(TrainerEventTrigger):
def total_batches(self, total_batches: int):
self.trainer_state.total_batches = total_batches

""" driver property """

@property
def model_device(self):
return self.driver.model_device

@property
def data_device(self):
return self.driver.data_device








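Under the simplified contract above, a callable validate_every receives only the trainer and returns a bool; it is consulted after every batch. A minimal sketch (the epoch threshold and the end-of-epoch test are illustrative assumptions):

def validate_late_epochs(trainer):
    # validate at the end of each epoch, but only after the 10th epoch
    return (trainer.cur_epoch_idx > 10 and
            trainer.batch_idx_in_epoch == trainer.num_batches_per_epoch)

# then pass it in: Trainer(..., validate_every=validate_late_epochs)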
+1 -1 fastNLP/core/controllers/utils/state.py

@@ -60,7 +60,7 @@ class TrainerState:
cur_epoch_idx: which epoch is currently running;
global_forward_batches: how many steps the model has forwarded in total so far;
batch_idx_in_epoch: which step of the current epoch training is on;
total_batches: how many steps each epoch forwards;
num_batches_per_epoch: how many steps each epoch forwards;
total_batches: the number of steps forwarded over the complete training run; note that total_batches = num_batches_per_epoch * n_epochs;
"""
n_epochs: Optional[int] = None # always recomputed, no matter what


+15 -5 fastNLP/core/controllers/utils/utils.py

@@ -1,8 +1,9 @@
from collections.abc import Iterator
import inspect
from typing import Dict

from fastNLP.core.callbacks import CallbackManager
from .state import TrainerState
from fastNLP.core.utils.utils import _check_valid_parameters_number


class TrainerEventTrigger:
@@ -68,12 +69,18 @@ class TrainerEventTrigger:
def on_after_backward(self):
self.callback_manager.on_after_backward(self)

def on_before_optimizer_step(self, optimizers):
self.callback_manager.on_before_optimizer_step(self, optimizers)
def on_before_optimizers_step(self, optimizers):
self.callback_manager.on_before_optimizers_step(self, optimizers)

def on_after_optimizers_step(self, optimizers):
self.callback_manager.on_after_optimizers_step(self, optimizers)

def on_before_zero_grad(self, optimizers):
self.callback_manager.on_before_zero_grad(self, optimizers)

def on_after_zero_grad(self, optimizers):
self.callback_manager.on_after_zero_grad(self, optimizers)

def on_validate_begin(self):
self.callback_manager.on_validate_begin(self)

@@ -119,5 +126,8 @@ class _TruncatedDataLoader:
return getattr(self.dataloader, item)




def check_validate_every(validate_every):
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0):
raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.")
if callable(validate_every):
_check_valid_parameters_number(validate_every, expected_params=['trainer'])

+1 -1 fastNLP/core/dataloaders/torch_dataloader/fdl.py

@@ -54,7 +54,7 @@ class TorchDataLoader(DataLoader):
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[Callable] = None,
multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
persistent_workers: bool = False, as_numpy: bool = False) -> None:
persistent_workers: bool = False, as_numpy: bool = False, **kwargs) -> None:
"""

:param dataset: a data container that implements __getitem__ and __len__


+11 -5 fastNLP/core/dataset/dataset.py

@@ -178,10 +178,11 @@ class DataSet:
elif isinstance(idx, slice):
if idx.start is not None and (idx.start >= len(self) or idx.start <= -len(self)):
raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self) - 1}")
data_set = DataSet()
dataset = DataSet()
for field_name, field in self.field_arrays.items():
data_set.add_field(field_name=field_name, fields=field.content[idx])
return data_set
dataset.add_field(field_name=field_name, fields=field.content[idx])
dataset.collate_fns = deepcopy(self.collate_fns)
return dataset
elif isinstance(idx, str):
if idx not in self:
raise KeyError("No such field called {} in DataSet.".format(idx))
@@ -192,6 +193,7 @@ class DataSet:
assert isinstance(i, int), "Only int index allowed."
instance = self[i]
dataset.append(instance)
dataset.collate_fns = deepcopy(self.collate_fns)
return dataset
else:
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
@@ -674,6 +676,8 @@ class DataSet:
dev_set.append(self[idx])
for idx in train_indices:
train_set.append(self[idx])
dev_set.collate_fns = deepcopy(self.collate_fns)
train_set.collate_fns = deepcopy(self.collate_fns)

return dev_set, train_set

@@ -788,13 +792,14 @@ class DataSet:

def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
"""
Set the padding value for each field_name; defaults to 0. This method only takes effect when an Auto_collate exists
Set the padding value for each field_name; defaults to 0. This method only takes effect when an AutoCollator exists
When val=None, it means none of the given field_names should attempt padding

:param field_names: field names that exist in the dataset
:param val: defaults to 0
:param val: defaults to 0. If None, the field is not padded.
:return:
"""
# TODO must not be empty
for field_name in field_names:
self.collate_fns.set_pad_val(field_name, val=val)

@@ -805,6 +810,7 @@ class DataSet:
:param field_names:
:return:
"""
#
self.collate_fns.set_input(*field_names)

def get_collator(self) -> _MultiCollator:

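A minimal sketch of the padding semantics documented above, assuming the usual dict-of-lists DataSet constructor; the field names are illustrative:

from fastNLP.core.dataset import DataSet

ds = DataSet({'words': [[1, 2, 3], [4, 5]], 'seq_len': [3, 2]})
ds.set_input('words', 'seq_len')     # only fields marked as input are collated
ds.set_pad_val('words', val=-100)    # pad 'words' with -100 instead of 0
ds.set_pad_val('seq_len', val=None)  # val=None: this field is not padded at all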

+2 -2 fastNLP/core/drivers/jittor_driver/jittor_driver.py

@@ -66,7 +66,7 @@ class JittorDriver(Driver):
if mode == "validate":
if not hasattr(model, "validate_step"):
if hasattr(model, "test_step"):
logger.warning(
logger.warning_once(
"Your model does not have 'validate_step' method but has 'test_step' method, but you"
"are using 'mode=validate', we are going to use 'test_step' to substitute for"
"'validate_step'.")
@@ -74,7 +74,7 @@ class JittorDriver(Driver):
else:
if not hasattr(model, "test_step"):
if hasattr(model, "validate_step"):
logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you"
logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you"
"are using 'mode=test', we are going to use 'validate_step' to substitute for"
"'test_step'.")



+71 -41 fastNLP/core/drivers/paddle_driver/fleet.py

@@ -10,6 +10,8 @@ from .utils import (
_MODE_PARAMETER,
get_device_from_visible,
reset_seed,
replace_sampler,
replace_batch_sampler,
)

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
@@ -19,8 +21,17 @@ from fastNLP.core.utils import (
paddle_move_data_to_device,
is_in_paddle_dist,
)
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedRandomSampler
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES
from fastNLP.core.samplers import (
RandomBatchSampler,
ReproducibleSampler,
ReproducibleBatchSampler,
RandomSampler,
UnrepeatedSampler,
UnrepeatedSequentialSampler,
re_instantiate_sampler,
conversion_between_reproducible_and_unrepeated_sampler,
)
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED
from fastNLP.core.log import logger

if _NEED_IMPORT_PADDLE:
@@ -93,8 +104,8 @@ class PaddleFleetDriver(PaddleDriver):
# we simply set model_device to None;
self._model_device = None

def _running_fn_(batch, step_fn, signature_fn):
if isinstance(batch, Dict):
def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call):
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(step_fn, batch, signature_fn=signature_fn)
else:
return self._validate_step(batch)
@@ -105,23 +116,21 @@ class PaddleFleetDriver(PaddleDriver):
"Notice your model is a `paddle.DataParallel` model. And your "
"model also implements the `train_step` method, which we can not call actually, we will"
" call `forward` function instead of `train_step` and you should note that.")
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward)
# self._train_signature_fn = model.forward
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)

if hasattr(model, "validate_step"):
logger.warning(
"Notice your model is a `paddle.DataParallel` model. And your "
"model also implements the `validate_step` method, which we can not call actually, "
"we will call `forward` function instead of `validate_step` and you should note that.")
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward)
# self._validate_signature_fn = model.forward
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)

if hasattr(model, "test_step"):
logger.warning(
"Notice your model is a `paddle.DataParallel` model. And your "
"model also implements the `test_step` method, which we can not call actually, we will"
" call `forward` function instead of `test_step` and you should note that.")
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward)
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)

# when the parameter `device` is None and this parameter is not None, it means the corresponding data is moved to the specified machine;
self._data_device = kwargs.get("data_device", None)
@@ -235,7 +244,6 @@ class PaddleFleetDriver(PaddleDriver):
"""
if self.local_rank == 0:
# if this is rank 0, spawn the other subprocesses
print("in launcher")
launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc)
launcher.launch()
# 设置参数和初始化分布式环境
@@ -253,7 +261,6 @@ class PaddleFleetDriver(PaddleDriver):
When the user launches with `python -m paddle.distributed.launch xxx.py`, we need to
obtain the various attributes from the environment variables set by paddle
"""
print("set_from_env")
self.world_size = dist.get_world_size()
self.global_rank = dist.get_rank()

@@ -267,9 +274,9 @@ class PaddleFleetDriver(PaddleDriver):
**self._fleet_kwargs
)

self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN})
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE})
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST})
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call)
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call)
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call)

self._configured = True

@@ -312,67 +319,90 @@ class PaddleFleetDriver(PaddleDriver):
def test_step(self, batch):
return self._test_step(batch)

def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]],
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]],
reproducible: bool = False, sampler_or_batch_sampler=None):
# IterableDataset is not supported for now
assert dataloader.dataset_kind != _DatasetKind.ITER, \
"FastNLP does not support `IteratorDataset` now."
# if dist is a ReproducibleBatchSampler / ReproducibleSampler, this is driver.load being called during checkpoint resumption;
if isinstance(dist, ReproducibleBatchSampler):
dist.set_distributed(
num_replicas=self.world_size,
rank=self.global_rank,
pad=True
)
return replace_batch_sampler(dataloader, dist)
if isinstance(dist, ReproducibleSampler):
dataloader.batch_sampler.sampler = dist
return dataloader

# paddle's BatchSampler and DataLoader have no shuffle attribute, so it can only be inferred from the sampler;
# the subclass DistributedBatchSampler, however, does have a shuffle attribute,
# so type() is used for a strict check
if type(dataloader.batch_sampler) == BatchSampler:
shuffle = isinstance(dataloader.batch_sampler.sampler, RandomSampler)
else:
shuffle = dataloader.batch_sampler.shuffle
dist.set_distributed(
num_replicas=self.world_size,
rank=self.global_rank,
pad=True
)
return replace_sampler(dataloader, dist)

# if dist is a str or None, this is being called during trainer initialization;
# trainer, evaluator
if dist is None:
if reproducible:
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our "
"control.")
else:
args = self.get_dataloader_args(dataloader)
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
batch_sampler = re_instantiate_sampler(args.batch_sampler)
return replace_batch_sampler(dataloader, batch_sampler)
if isinstance(args.sampler, ReproducibleSampler):
sampler = re_instantiate_sampler(args.sampler)
return replace_sampler(dataloader, sampler)
return dataloader
# trainer
elif dist == "dist":
args = self.get_dataloader_args(dataloader)
# if the user's trainer.use_dist_sampler is True, whether checkpoint resumption is enabled does not affect the behavior here;
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler):
dataloader.batch_sampler.sampler.set_distributed(
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
batch_sampler = re_instantiate_sampler(args.batch_sampler)
batch_sampler.set_distributed(
num_replicas=self.world_size,
rank=self.global_rank,
pad=True
)
return dataloader
return replace_batch_sampler(dataloader, batch_sampler)
elif isinstance(args.sampler, ReproducibleSampler):
sampler = re_instantiate_sampler(args.sampler)
sampler.set_distributed(
num_replicas=self.world_size,
rank=self.global_rank,
pad=True
)
return replace_sampler(dataloader, sampler)
else:
sampler = RandomSampler(
dataset=dataloader.dataset,
shuffle=shuffle,
seed=int(os.environ.get("FASTNLP_SEED", 0))
dataset=args.dataset,
shuffle=args.shuffle,
seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0))
)
sampler.set_distributed(
num_replicas=self.world_size,
rank=self.global_rank,
pad=True
)
dataloader.batch_sampler.sampler = sampler
return dataloader
return replace_sampler(dataloader, sampler)
# evaluator
elif dist == "unrepeatdist":
sampler = UnrepeatedRandomSampler(
dataset=dataloader.dataset,
shuffle=shuffle,
seed=int(os.environ.get("FASTNLP_SEED", 0))
)
args = self.get_dataloader_args(dataloader)
if isinstance(args.sampler, ReproducibleSampler):
sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
elif not isinstance(args.sampler, UnrepeatedSampler):
sampler = UnrepeatedSequentialSampler(
dataset=args.dataset
)
else:
sampler = re_instantiate_sampler(args.sampler)
sampler.set_distributed(
num_replicas=self.world_size,
rank=self.global_rank
)
dataloader.batch_sampler.sampler = sampler
return dataloader
return replace_sampler(dataloader, sampler)
else:
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")



+ 12
- 13
fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py View File

@@ -38,23 +38,19 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
if driver not in {"paddle", "fleet"}:
raise ValueError("Parameter `driver` can only be one of these values: ['paddle', 'fleet'].")

cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES")
# priority: user > cuda
# check the validity of device for the single-machine case
# in the distributed case this is checked via world_device
if user_visible_devices != "":
_could_use_device_num = len(user_visible_devices.split(","))
elif cuda_visible_devices is not None:
_could_use_device_num = len(cuda_visible_devices.split(","))
else:
_could_use_device_num = paddle.device.cuda.device_count()
if user_visible_devices is None:
raise RuntimeError("This situation cannot happen, please report a bug to us.")
_could_use_device_num = len(user_visible_devices.split(","))
if isinstance(device, int):
if device < 0 and device != -1:
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.")
# if device >= _could_use_device_num:
# raise ValueError("The gpu device that parameter `device` specifies is not existed.")
device = f"gpu:{device}"
if device >= _could_use_device_num:
raise ValueError("The gpu device that parameter `device` specifies is not existed.")
if device != -1:
device = f"gpu:{device}"
else:
device = list(range(_could_use_device_num))
elif isinstance(device, Sequence) and not isinstance(device, str):
device = list(set(device))
for each in device:
@@ -62,6 +58,9 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be 'int' type.")
elif each < 0:
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be bigger than 0.")
elif each >= _could_use_device_num:
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should not be bigger than"
" the available gpu number.")
if len(device) == 1:
# a list like [1] was passed in; treat it as single-card.
device = device[0]
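
The validation above amounts to a small normalization table; a hedged sketch of the assumed semantics, simplified from the function above (element-by-element sequence checks omitted):

def normalize_device(device, usable_num):
    # int: -1 means all usable gpus, otherwise a single "gpu:x" string
    if isinstance(device, int):
        if device < -1:
            raise ValueError("negative device ids other than -1 are not allowed")
        if device == -1:
            return list(range(usable_num))
        if device >= usable_num:
            raise ValueError("device id out of range")
        return f"gpu:{device}"
    # sequence: de-duplicate and collapse a one-element list to a scalar
    device = sorted(set(device))
    return device[0] if len(device) == 1 else device

print(normalize_device(2, usable_num=4))       # gpu:2
print(normalize_device(-1, usable_num=2))      # [0, 1]
print(normalize_device([3, 3], usable_num=4))  # 3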


+ 219
- 59
fastNLP/core/drivers/paddle_driver/paddle_driver.py View File

@@ -1,21 +1,36 @@
import os
import random
from typing import Union, Optional, Callable, Dict
from typing import Union, Optional, Dict
from pathlib import Path
from functools import partial
from dataclasses import dataclass

import numpy as np

from .utils import _build_fp16_env
from .utils import _build_fp16_env, optimizer_state_to_device
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.core.drivers.driver import Driver
from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS
from fastNLP.envs import (
FASTNLP_SEED_WORKERS,
FASTNLP_MODEL_FILENAME,
FASTNLP_CHECKPOINT_FILENAME,
FASTNLP_GLOBAL_RANK,
rank_zero_call,
)
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler

if _NEED_IMPORT_PADDLE:
import paddle
from paddle.io import DataLoader, IterableDataset
from paddle.io import (
DataLoader,
IterableDataset,
Dataset,
Sampler,
BatchSampler,
RandomSampler,
)
from paddle.optimizer import Optimizer

_reduces = {
@@ -41,6 +56,9 @@ class PaddleDriver(Driver):
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
self.grad_scaler = _grad_scaler()

# whether to disable the automatic argument matching in auto_param_call;
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)

def zero_grad(self, set_to_none: bool = False):
r"""
Zero the gradients; this should be done directly through the optimizers;
@@ -48,8 +66,8 @@ class PaddleDriver(Driver):

:param set_to_none: whether to set the gradients directly to None; this parameter has no effect in Paddle.
"""
# if set_to_none:
# log.warning("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.")
if set_to_none:
logger.warning_once("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.")
for optimizer in self.optimizers:
optimizer.clear_grad()

@@ -69,6 +87,8 @@ class PaddleDriver(Driver):
# TODO for now we forbid the dataloader's dataset from being an IterableDataset;
if isinstance(dataloader.dataset, IterableDataset):
raise TypeError("`IterableDataset` is not allowed.")
if dataloader.batch_sampler is None and dataloader.batch_size is None:
raise ValueError(f"At least one of `{dataloader_name}`'s `batch_sampler` and `batch_size` should be set.")
else:
if not isinstance(dataloader, Dict):
raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.")
@@ -79,6 +99,9 @@ class PaddleDriver(Driver):
f"type, not {type(each_dataloader)}.")
if isinstance(each_dataloader.dataset, IterableDataset):
raise TypeError("`IterableDataset` is not allowed.")
if each_dataloader.batch_sampler is None and each_dataloader.batch_size is None:
raise ValueError(f"For each dataloader of parameter `{dataloader_name}`, at least one of "
f"`batch_sampler` and `batch_size` should be set.")

@staticmethod
def _check_optimizer_legality(optimizers):
@@ -110,7 +133,7 @@ class PaddleDriver(Driver):
else:
if not hasattr(model, "test_step"):
if hasattr(model, "validate_step"):
logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you"
logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you"
"are using 'Evaluator.test', we are going to use 'validate_step' to substitute for"
"'test_step'.")

@@ -153,45 +176,55 @@ class PaddleDriver(Driver):
getattr(self.model, mode)()

@rank_zero_call
def save_model(self, filepath: str, only_state_dict: bool = True, model_save_fn: Optional[Callable]=None, **kwargs):
def save_model(self, filepath: str, only_state_dict: bool = True, **kwargs):
r"""
Function for saving the model; note that the function `save` is the one used for checkpoint resumption;
if `model_save_fn` is a callable, we run that function directly;

:param filepath: location of the file to save to (must include the file name);
:param only_state_dict: whether to save only the model's `state_dict`; note this parameter only takes effect when `model_save_fn` is None;
:param model_save_fn: a user-supplied function that replaces this function's own saving logic; if it is not None, we call model_save_fn(path);
:param only_state_dict: whether to save only the model's `state_dict`; if False, the `paddle.jit.save` function is called
to save the parameters of the whole model, in which case the `input_spec` parameter must be passed, otherwise loading will fail.
:param kwargs:
input_spec: describes the input of the stored model's forward method; it must be passed when `only_state_dict` is False, otherwise loading will fail.
It can be described with InputSpec or with example Tensors. For details, see the paddle documentation
on `paddle.jit.save`:
https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/jit/save_cn.html#save
:return:
"""
if model_save_fn is not None:
model_save_fn(filepath)
model = self.unwrap_model()
if isinstance(filepath, Path):
filepath = str(filepath)
if only_state_dict:
states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
paddle.save(states, filepath)
else:
model = self.unwrap_model()
if only_state_dict:
paddle.save(model.state_dict(), filepath)
else:
input_spec = kwargs.get("input_spec", None)
if input_spec is None:
raise Exception("To save the whole Paddle Layer, parameter 'input_spec' is needed.")
paddle.jit.save(model, filepath, input_spec)
# paddle needs extra arguments when saving the whole model
input_spec = kwargs.get("input_spec", None)
if input_spec is None:
raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.")
paddle.jit.save(model, filepath, input_spec)

@staticmethod
@rank_zero_call
def load_model(filepath: str, load_dict: bool = True):
def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs):
r"""
Function for loading the model; note that the function `load` is the one used for checkpoint resumption;

:param filepath: location of the file to be loaded (must include the file name);
:param load_dict: whether to load the state_dict, default True. If the user set only_state_dict to False in save_model,
i.e. saved the whole model, this parameter must also be False
:return: the result of loading the specified file;
:param only_state_dict: whether to load the state_dict, default True.
:param kwargs:
:return:
"""
if load_dict:
return paddle.load(filepath)
else:
return paddle.jit.load(filepath)
model = self.unwrap_model()
if isinstance(filepath, Path):
filepath = str(filepath)
# in paddle, a model saved via paddle.jit.save can also be loaded as the corresponding state dict via paddle.load,
# but the input path is then required to have the form dir/filename, otherwise an error is raised.
dirname, filename = os.path.split(filepath)
if not only_state_dict and dirname == "":
# if a bare file name was passed in, prepend a relative path
filepath = os.path.join(".", filepath)
model.load_dict(paddle.load(filepath))

@rank_zero_call
def save(self, folder, states: Dict):
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
r"""
The saving function for checkpoint resumption; it is responsible for saving the model's and the optimizers' state_dicts;
note that a driver should be stateless, i.e. its interface functions should return the same results no matter when they are called; checkpoint resumption therefore does not need to save the driver
@@ -203,48 +236,114 @@ class PaddleDriver(Driver):
:param states: a dict passed in by the trainer that already contains the states of the other objects that must be saved for checkpoint resumption; the Driver only needs to save
this object and should not need to understand it; on driver.load() the states must be handed back, and the value returned by load() must stay consistent with
what is passed in here.
:param dataloader: the dataloader currently in use; its internal state must be saved so that iteration can later resume from the current position.
:param only_state_dict: whether to save only the model's parameters; this parameter has no effect when should_save_model is False.
:param should_save_model: whether the model should be saved; if False, the Driver is not responsible for saving the model.
:return:
"""
# 1. Save the model's state;
model = self.unwrap_model()
model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
# for a single-card driver we should not (for now) worry about users running single-card mode in a DDP environment at the cost of efficiency;
states["model_state_dict"] = model_state_dict
# the dataloader passed in is the trainer's dataloader attribute, because we never modify any of the driver's dataloaders; instead we change
# trainer.dataloader to change the dataloader's state, so as to fit the training or evaluation environment;

# 1. The sampler's state, because we support resume training, i.e. restoring precisely to a specific batch;
# after initialization, a paddle DataLoader's batch_sampler may be None, or it may be the batch_sampler set by the user
dataloader_args = self.get_dataloader_args(dataloader)
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
sampler = dataloader_args.batch_sampler
elif dataloader_args.sampler:
sampler = dataloader_args.sampler
else:
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")

num_consumed_batches = states.pop('num_consumed_batches')
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
sampler_states = sampler.state_dict()
# if present, num_consumed_samples needs special handling: because the DataLoader prefetches, using the sampler's num_consumed_samples
# directly would count more than was actually consumed.
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
if num_consumed_samples_array is not None:
if isinstance(sampler, ReproducibleSampler):  # for a sampler, batch_size must be taken into account.
try:
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
except:  # batch_size may be None, in which case some precision is lost
num_consumed_batches = sampler_states['num_consumed_samples']
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
else:
raise RuntimeError(
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')

# 2. Save the optimizers' states;
# 2. Save the model's state;
if should_save_model:
self.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs)
if only_state_dict:
logger.debug("Save model state dict.")
else:
logger.debug("Save model.")

# 3. Save the optimizers' states;
optimizers_state_dict = {}
for i in range(len(self.optimizers)):
optimizer: Optimizer = self.optimizers[i]
optimizer_state = optimizer.state_dict()
optimizer_state = {name: param.cpu().detach().clone() for name, param in optimizer_state.items()}
optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu")
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的;
states["optimizers_state_dict"] = optimizers_state_dict

paddle.save(states, folder)

def load(self, filepath) -> Dict:
r"""
The loading function for checkpoint resumption; note that this function is responsible for reading the data and restoring the model's and the optimizers' state_dicts, among others;
in this function the driver instance should first load the model's and the optimizers' state_dicts and then return a state dict to the trainer.
The accepted and returned values of the save and load functions should therefore correspond;

This function needs to be executed on all ranks.
logger.debug("Save optimizer state dict.")
states["optimizers_state_dict"] = optimizers_state_dict
paddle.save(states, str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)))

:param filepath: the file name of the saved checkpoint-resumption state;
:return: must return the states content that was passed into the save function;
"""
states = paddle.load(filepath)
def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
states = paddle.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)))

# 1. Load the optimizers' states;
optimizers_state_dict = states["optimizers_state_dict"]
for i in range(len(self.optimizers)):
optimizer: paddle.optimizer.Optimizer = self.optimizers[i]
optimizer: Optimizer = self.optimizers[i]
optimizer.set_state_dict(optimizers_state_dict[f"optimizer{i}"])
logger.debug("Load optimizer state dict.")

# 2. Load the model's state;
model = self.unwrap_model()
model.load_dict(states["model_state_dict"])
if should_load_model:
self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict)
if only_state_dict:
logger.debug("Load model state dict.")
else:
logger.debug("Load model.")

# 3. Restore the sampler's state;
dataloader_args = self.get_dataloader_args(dataloader)
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
sampler = dataloader_args.batch_sampler
elif isinstance(dataloader_args.sampler, ReproducibleSampler):
sampler = dataloader_args.sampler
elif self.is_distributed():
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.")
else:
sampler = RandomBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
)
sampler.load_state_dict(states['sampler_states'])
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)

# 4. Adjust trainer_state.batch_idx_in_epoch
# the sampler is a RandomSampler-like sampler, not a batch_sampler;
if not isinstance(sampler, ReproducibleBatchSampler):
if dataloader_args.drop_last:
batch_idx_in_epoch = len(
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
else:
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
# the sampler is a batch_sampler;
else:
batch_idx_in_epoch = sampler.batch_idx_in_epoch

states["batch_idx_in_epoch"] = batch_idx_in_epoch

self.barrier()
return states

def get_evaluate_context(self):
@@ -282,7 +381,7 @@ class PaddleDriver(Driver):
`randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_.
"""
# implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
global_rank = rank if rank is not None else rank_zero_call.rank
global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))
# TODO gpu
process_seed = paddle.fluid.core.default_cpu_generator().initial_seed()
# back out the base seed so we can use all the bits
@@ -313,3 +412,64 @@ class PaddleDriver(Driver):
"""
if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
dataloader.batch_sampler.set_epoch(cur_epoch_idx)

@staticmethod
def get_dataloader_args(dataloader: "DataLoader"):
"""
Get the dataloader's dataset, batch_sampler, sampler, batch_size, shuffle and drop_last attributes;
"""

@dataclass
class Res:
dataset: Optional[Dataset] = None
batch_sampler: Optional[BatchSampler] = None
sampler: Optional[Sampler] = None
batch_size: Optional[int] = None
shuffle: Optional[bool] = None
drop_last: Optional[bool] = None

res = Res()

# a paddle DataLoader always has a dataset attribute;
res.dataset = dataloader.dataset

if dataloader.batch_sampler is not None:
# in paddle, however, we require that batch_sampler is not None
res.batch_sampler = dataloader.batch_sampler
if hasattr(dataloader.batch_sampler, "batch_size"):
res.batch_size = getattr(dataloader.batch_sampler, "batch_size")
# 用户使用的是自己的 batch_sampler 并且其没有 "batch_size" 属性;
else:
dataloader_iter = iter(dataloader)
pre_sample = next(dataloader_iter)
res.batch_size = pre_sample.shape[0]

if hasattr(dataloader.batch_sampler, "sampler"):
res.sampler = dataloader.batch_sampler.sampler
if hasattr(dataloader.batch_sampler.sampler, "shuffle"):
res.shuffle = dataloader.batch_sampler.sampler.shuffle
elif isinstance(dataloader.batch_sampler.sampler, RandomSampler):
res.shuffle = True
else:
res.shuffle = False
# the RandomBatchSampler case
elif hasattr(dataloader.batch_sampler, "batch_sampler"):
batch_sampler = dataloader.batch_sampler.batch_sampler
res.sampler = batch_sampler.sampler
if hasattr(batch_sampler.sampler, "shuffle"):
res.shuffle = batch_sampler.sampler.shuffle
elif isinstance(batch_sampler.sampler, RandomSampler):
res.shuffle = True
else:
res.shuffle = False
else:
res.sampler = None
res.shuffle = False

if hasattr(dataloader.batch_sampler, "drop_last"):
res.drop_last = getattr(dataloader.batch_sampler, "drop_last")
# 用户使用的是自己的 batch_sampler 并且其没有 "drop_last" 属性;
else:
res.drop_last = False

return res
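
The `batch_idx_in_epoch` recovery in `load` above is plain ceiling arithmetic; a worked toy check (assumed meaning: `len(sampler)` is the number of samples this rank owns, `num_left_samples` the ones not yet consumed):

batch_size, total, left = 4, 10, 3   # 10 samples, 3 still unconsumed

# drop_last: count whole batches on both sides
idx_drop = total // batch_size - left // batch_size   # 2 - 0 = 2
# keep the last partial batch: ceiling division on both sides
idx_keep = (total + batch_size - 1) // batch_size - (left + batch_size - 1) // batch_size   # 3 - 1 = 2

print(idx_drop, idx_keep)   # 2 2 -> training resumes at batch index 2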

+ 46
- 41
fastNLP/core/drivers/paddle_driver/single_device.py View File

@@ -2,6 +2,7 @@ import os
from typing import Optional, Dict, Union

from .paddle_driver import PaddleDriver
from .utils import replace_batch_sampler, replace_sampler, get_device_from_visible
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES
from fastNLP.core.utils import (
@@ -10,7 +11,12 @@ from fastNLP.core.utils import (
get_paddle_device_id,
paddle_move_data_to_device,
)
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
from fastNLP.core.samplers import (
ReproducibleBatchSampler,
RandomBatchSampler,
ReproducibleSampler,
re_instantiate_sampler,
)
from fastNLP.core.log import logger

if _NEED_IMPORT_PADDLE:
@@ -22,16 +28,13 @@ __all__ = [
]

class PaddleSingleDriver(PaddleDriver):
def __init__(self, model, device: Optional[str], fp16: Optional[bool] = False, **kwargs):
def __init__(self, model, device: str, fp16: Optional[bool] = False, **kwargs):
super(PaddleSingleDriver, self).__init__(model, fp16=fp16, **kwargs)

if device is None:
raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.")

if isinstance(device, int):
self.model_device = get_paddle_gpu_str(device)
else:
self.model_device = device
self.model_device = get_paddle_gpu_str(device)

self.local_rank = 0
self.global_rank = 0
@@ -93,18 +96,18 @@ class PaddleSingleDriver(PaddleDriver):
self._test_signature_fn = model.forward

def setup(self):
user_visible_devices = os.environ[USER_CUDA_VISIBLE_DEVICES]
device_id = get_paddle_device_id(self.model_device)
if user_visible_devices is not None and user_visible_devices != "":
# not empty, which means the user set CUDA_VISIBLE_DEVICES
device_id = user_visible_devices.split(",")[device_id]
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
paddle.device.set_device("gpu:0")
self.model.to("gpu:0")
device = self.model_device
if device != "cpu":
device_id = get_paddle_device_id(device)
device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id]
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
device = get_device_from_visible(device, output_type=str)
paddle.device.set_device(device)
self.model.to(device)

def train_step(self, batch) -> Dict:
# if batch is a Dict we do the argument matching for it by default; otherwise it is passed straight into the `train_step` function for the user to handle;
if isinstance(batch, Dict):
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
else:
return self._train_step(batch)
@@ -118,13 +121,13 @@ class PaddleSingleDriver(PaddleDriver):
self.grad_scaler.update()

def validate_step(self, batch) -> Dict:
if isinstance(batch, Dict):
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
else:
return self._validate_step(batch)

def test_step(self, batch) -> Dict:
if isinstance(batch, Dict):
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
else:
return self._test_step(batch)
@@ -133,38 +136,40 @@ class PaddleSingleDriver(PaddleDriver):
r"""
Move the data to the specified device; batch may be a list or a dict, or a nested structure of them.
Note that in Paddle this may cause problems arising from inconsistency with the configured device.
In the single-card case, since CUDA_VISIBLE_DEVICES is always restricted to one device, the data is actually only moved to `gpu:0`

:return: the batch object moved to the specified device;
"""
return paddle_move_data_to_device(batch, "gpu:0")
device = get_device_from_visible(self.data_device)
return paddle_move_data_to_device(batch, device)

def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None,
reproducible: bool = False):

def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler],
reproducible: bool = False, sampler_or_batch_sampler=None):
# IteratorDataset is not supported for now
# IterableDataset is not supported for now
assert dataloader.dataset_kind != _DatasetKind.ITER, \
"FastNLP does not support `IteratorDataset` now."
"FastNLP does not support `IteratorDataset` now."
# if dist is a ReproducibleBatchSampler / ReproducibleIterator, this is driver.load being called during checkpoint resumption;
if isinstance(dist, ReproducibleBatchSampler):
dataloader.batch_sampler = dist
return dataloader
if isinstance(dist, ReproducibleSampler):
dataloader.batch_sampler.sampler = dist
return dataloader
return replace_batch_sampler(dataloader, dist)
elif isinstance(dist, ReproducibleSampler):
return replace_sampler(dataloader, dist)

# if dist is a str or None, this is being called during trainer initialization;
args = self.get_dataloader_args(dataloader)
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
batch_sampler = re_instantiate_sampler(args.batch_sampler)
return replace_batch_sampler(dataloader, batch_sampler)
elif isinstance(args.sampler, ReproducibleSampler):
sampler = re_instantiate_sampler(args.sampler)
return replace_sampler(dataloader, sampler)

if reproducible:
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler):
return dataloader
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler):
return dataloader
else:
# TODO
batch_sampler = ReproducibleBatchSampler(
batch_sampler=dataloader.batch_sampler,
batch_size=dataloader.batch_sampler.batch_size,
drop_last=dataloader.drop_last
)
dataloader.batch_sampler = batch_sampler
return dataloader
batch_sampler = RandomBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last
)
return replace_batch_sampler(dataloader, batch_sampler)
else:
return dataloader
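
`re_instantiate_sampler` above hands back a fresh sampler that shares no iteration state with the old dataloader; a minimal sketch of the assumed idea (not the exact fastNLP implementation):

import inspect

def re_instantiate_sampler_sketch(sampler, new_sampler_class=None):
    cls = new_sampler_class or type(sampler)
    accepted = inspect.signature(cls.__init__).parameters
    # assumes constructor arguments are stored as same-named instance attributes
    kwargs = {k: v for k, v in vars(sampler).items() if k in accepted}
    return cls(**kwargs)

A batch sampler rebuilt this way keeps its dataset, batch size and seed but starts from a clean iteration position.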



+ 69
- 38
fastNLP/core/drivers/paddle_driver/utils.py View File

@@ -4,12 +4,14 @@ import struct
import random
import inspect
import numpy as np
from copy import deepcopy
from contextlib import ExitStack, closing
from enum import IntEnum
from typing import Dict, Optional, Union

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.core.utils import get_paddle_device_id, auto_param_call
from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to
from fastNLP.core.samplers import RandomSampler
from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES
from fastNLP.core.log import logger

@@ -18,7 +20,7 @@ if _NEED_IMPORT_PADDLE:
import paddle
from paddle import nn
from paddle.nn import Layer
from paddle.io import DataLoader, BatchSampler
from paddle.io import DataLoader, BatchSampler, Dataset
from paddle.amp import auto_cast, GradScaler
else:
from fastNLP.core.utils.dummy_class import DummyClass as Layer
@@ -85,7 +87,7 @@ class ForwardState(IntEnum):
TEST = 2
PREDICT = 3

_MODE_PARAMETER = "_forward_state"
_MODE_PARAMETER = "forward_state"

class _FleetWrappingModel(Layer):
"""
@@ -151,24 +153,25 @@ class _FleetWrappingModel(Layer):

def forward(self, batch, **kwargs) -> Dict:

_forward_state = kwargs.pop(_MODE_PARAMETER)
forward_state = kwargs.pop(_MODE_PARAMETER)
wo_auto_param_call = kwargs.pop("wo_auto_param_call")

if _forward_state == ForwardState.TRAIN:
if isinstance(batch, Dict):
if forward_state == ForwardState.TRAIN:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
else:
return self._train_step(batch)
elif _forward_state == ForwardState.VALIDATE:
if isinstance(batch, Dict):
elif forward_state == ForwardState.VALIDATE:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
else:
return self._validate_step(batch)
elif _forward_state == ForwardState.TEST:
if isinstance(batch, Dict):
elif forward_state == ForwardState.TEST:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
else:
return self._test_step(batch)
elif _forward_state == ForwardState.PREDICT:
elif forward_state == ForwardState.PREDICT:
raise NotImplementedError("'PREDICT' mode has not been implemented.")
else:
raise NotImplementedError("You should direct a concrete mode.")
@@ -205,7 +208,6 @@ class DummyGradScaler:
def state_dict(self):
return {}


def _build_fp16_env(dummy=False):
if dummy:
auto_cast = ExitStack
@@ -255,61 +257,77 @@ def get_host_name_ip():
except:
return None

def get_device_from_visible(device: Union[str, int]):
def get_device_from_visible(device: Union[str, int], output_type=int):
"""
Get the corresponding device when CUDA_VISIBLE_DEVICES is set.
E.g. with CUDA_VISIBLE_DEVICES=2,3 and device=3, 1 is returned.
:param devices: the untranslated device name
:param device: the untranslated device name
:param output_type: the type of the return value
:return: the translated device id
"""
if output_type not in [int, str]:
raise ValueError("Parameter `output_type` should be one of these types: [int, str]")
if device == "cpu":
return device
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
idx = get_paddle_device_id(device)
if cuda_visible_devices is None or cuda_visible_devices == "":
# this case should generally not happen, because fastnlp forcibly injects CUDA_VISIBLE_DEVICES for paddle
return idx
raise RuntimeError("This situation should not happen, please report us this bug.")
else:
# use USER_CUDA_VISIBLE_DEVICES to get the devices the user expects
user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
if user_visible_devices is not None and user_visible_devices != "":
# not empty, which means the user set CUDA_VISIBLE_DEVICES
idx = user_visible_devices.split(",")[idx]
else:
idx = str(idx)
if user_visible_devices is None:
raise RuntimeError("This situation cannot happen, please report a bug to us.")
idx = user_visible_devices.split(",")[idx]

cuda_visible_devices_list = cuda_visible_devices.split(',')
assert idx in cuda_visible_devices_list, "Can't find "\
"your devices %s in CUDA_VISIBLE_DEVICES[%s]."\
% (idx, cuda_visible_devices)
if idx not in cuda_visible_devices_list:
raise ValueError(f"Can't find your devices {idx} in CUDA_VISIBLE_DEVICES[{cuda_visible_devices}].")
res = cuda_visible_devices_list.index(idx)
return res
if output_type == int:
return res
else:
return f"gpu:{res}"

def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"):
# get the instance attributes;
def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler"):
"""
Rebuild a DataLoader using `batch_sampler`, so that the `batch_sampler` is replaced without affecting the original `dataloader`.
The case where the user has customized the DataLoader is taken into account.
"""
# get the instance attributes that do not start with an underscore;
instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')}

# get the default signature of the dataloader's '__init__' function;
# get the default signature of the dataloader's '__init__' function; this yields the parameter names, their default values and types
init_params = dict(inspect.signature(dataloader.__init__).parameters)

# the reason this is handled separately: when customizing a dataloader, the user may for convenience set only some parameters explicitly and
# pass the rest via **kwargs; if they added new parameters when initializing their own dataloader instance (this step is necessary in the first
# place, because it is the only way to pass in a sampler; on the other hand, the user may indeed add new parameters via **kwargs), and assuming
# they wrote "super().__init__(**kwargs)", then we can only look the parameters up in DataLoader;
# they wrote "super().__init__(**kwargs)", then we can only look the parameters up in DataLoader; VAR_KEYWORD stands for **kwargs
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
if has_variadic_kwargs:
init_params.update(dict(inspect.signature(DataLoader.__init__).parameters))
del init_params["self"]

# since we may just have overwritten the user's customized dataloader parameters with DataLoader's defaults, this has to be redone;
# collect the parameters that appear both as instance attributes and as parameter names and whose values differ from the defaults
non_default_params = {name for name, p in init_params.items() if
name in instance_attrs and p.default != instance_attrs[name]}
# add `dataset` as it might have been replaced with `*args`
non_default_params.add("dataset")

# collect the non-default parameters and their values
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
reconstruct_args.update({"batch_sampler": sampler, "shuffle": False, "drop_last": False, "batch_size": 1})

# the class member corresponding to persistent_workers carries a leading underscore, so add it here
reconstruct_args.update({
"batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1,
"persistent_workers": dataloader._persistent_workers,
})

# POSITIONAL_OR_KEYWORD stands for ordinary parameters
# collect the parameters of the init function that are ordinary, have no default value and are not in reconstruct_args,
# i.e. those that do not appear in both the init function and the instance attributes
required_args = {
p.name
for p in init_params.values()
@@ -323,12 +341,9 @@ def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"):
required_args = sorted(required_args)
dataloader_self_name = dataloader.__class__.__name__
raise Exception(
f"Trying to inject `DistributedBatchSampler` into the `{dataloader_self_name}` instance. "
f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. "
"This would fail as some of the `__init__` arguments are not available as instance attributes. "
f"The missing attributes are {required_args}. "
f"HINT: If you wrote the `{dataloader_self_name}` class, define `self.missing_arg_name` or "
"manually add the `DistributedBatchSampler` as: "
f"`{dataloader_self_name}(dataset, sampler=DistributedBatchSampler(dataset))`."
)

# this error targets the case where the dataloader passed in is not a plain DataLoader but a customized one whose __init__ has no **kwargs;
@@ -340,12 +355,28 @@ def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"):
missing_kwargs = sorted(missing_kwargs)
dataloader_self_name = dataloader.__class__.__name__
raise Exception(
f"Trying to inject `DistributedBatchSampler` into the `{dataloader_self_name}` instance. "
f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. "
"This would fail as it doesn't expose all its attributes in the `__init__` signature. "
f"The missing arguments are {missing_kwargs}. "
f"HINT: If you wrote the `{dataloader_self_name}` class, add the `__init__` arguments or "
"manually add the `DistributedBatchSampler` as: "
f"`{dataloader_self_name}(dataset, sampler=DistributedBatchSampler(dataset))`."
)

return type(dataloader)(**reconstruct_args)

def replace_sampler(dataloader, new_sampler):
"""
Rebuild a BatchSampler using `new_sampler` and swap it into `dataloader`
"""
new_batch_sampler = deepcopy(dataloader.batch_sampler)
new_batch_sampler.sampler = new_sampler
return replace_batch_sampler(dataloader, new_batch_sampler)

def optimizer_state_to_device(state, device):
new_state = {}
for name, param in state.items():
if isinstance(param, dict):
new_state[name] = optimizer_state_to_device(param, device)
elif isinstance(param, paddle.Tensor):
new_state[name] = paddle_to(param, device).clone()
else:
new_state[name] = param
return new_state
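
The reconstruction trick in `replace_batch_sampler` can be condensed to a few lines; a hedged sketch on a toy loader class (not paddle's DataLoader):

import inspect

class MyLoader:
    def __init__(self, dataset, batch_sampler=None, num_workers=0):
        self.dataset = dataset
        self.batch_sampler = batch_sampler
        self.num_workers = num_workers

def rebuild_with(obj, **overrides):
    # instance attributes that match __init__ parameters become constructor kwargs
    attrs = {k: v for k, v in vars(obj).items() if not k.startswith("_")}
    params = inspect.signature(type(obj).__init__).parameters
    kwargs = {k: v for k, v in attrs.items() if k in params}
    kwargs.update(overrides)
    return type(obj)(**kwargs)

loader = MyLoader(dataset=[1, 2, 3], num_workers=2)
new_loader = rebuild_with(loader, batch_sampler="fresh_batch_sampler")
print(new_loader.num_workers, new_loader.batch_sampler)  # 2 fresh_batch_sampler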

+ 17
- 26
fastNLP/core/drivers/torch_driver/ddp.py View File

@@ -12,6 +12,7 @@ if _NEED_IMPORT_TORCH:
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import BatchSampler

__all__ = [
'TorchDDPDriver'
@@ -167,6 +168,7 @@ class TorchDDPDriver(TorchDriver):
不管是什么情况,`TorchDDPDriver` 在 `setup` 函数的最后,都会将所有进程的 pid 主动记录下来,这样当一个进程出现 exception 后,
driver 的 on_exception 函数就会被 trainer 调用,其会调用 os.kill 指令将其它进程 kill 掉;
"""
# after adding many things, be careful about where the super call happens here;
super(TorchDDPDriver, self).__init__(model, fp16=fp16, **kwargs)

if isinstance(model, torch.nn.DataParallel):
@@ -202,8 +204,8 @@ class TorchDDPDriver(TorchDriver):
# we simply set model_device to None;
self.model_device = None

def _running_fn_(batch, step_fn, signature_fn):
if isinstance(batch, Dict):
def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call):
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(step_fn, batch, signature_fn=signature_fn)
else:
return step_fn(batch)
@@ -214,7 +216,7 @@ class TorchDDPDriver(TorchDriver):
"Notice your model is a `DistributedDataParallel` model. And your "
"model also implements the `train_step` method, which we can not call actually, we will"
" call `forward` function instead of `train_step` and you should note that.")
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward)
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
# self._train_signature_fn = model.forward

if hasattr(model, "validate_step"):
@@ -222,7 +224,7 @@ class TorchDDPDriver(TorchDriver):
"Notice your model is a `DistributedDataParallel` model. And your "
"model also implements the `validate_step` method, which we can not call actually, "
"we will call `forward` function instead of `validate_step` and you should note that.")
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward)
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
# self._validate_signature_fn = model.forward

if hasattr(model, "test_step"):
@@ -230,14 +232,11 @@ class TorchDDPDriver(TorchDriver):
"Notice your model is a `DistributedDataParallel` model. And your "
"model also implements the `test_step` method, which we can not call actually, we will"
" call `forward` function instead of `test_step` and you should note that.")
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward)
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
# self._test_signature_fn = model.forward

# when the user initializes DDP outside our control we set model_device to None; the user can then move the data to the desired device via `data_device`;
self._data_device = kwargs.get("data_device", None)
# if self.outside_ddp and self._data_device is None:
# raise RuntimeError("When you initialize your ddp out of our control, the parameter "
# "`data_device` can not be None.")
if isinstance(self._data_device, int):
if self._data_device < 0:
raise ValueError("Parameter `data_device` can not be smaller than 0.")
@@ -349,9 +348,9 @@ class TorchDDPDriver(TorchDriver):
**self._ddp_kwargs
)

self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN})
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE})
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST})
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call)
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call)
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call)

self._configured = True

@@ -472,12 +471,11 @@ class TorchDDPDriver(TorchDriver):
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
"control.")
else:
if isinstance(dist, ReproducibleBatchSampler):
dist = re_instantiate_sampler(dist)
return replace_batch_sampler(dataloader, dist)
if isinstance(dist, ReproducibleSampler):
dist = re_instantiate_sampler(dist)
return replace_sampler(dataloader, dist)
args = self.get_dataloader_args(dataloader)
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
return replace_batch_sampler(dataloader, re_instantiate_sampler(args.batch_sampler))
if isinstance(args.sampler, ReproducibleSampler):
return replace_sampler(dataloader, re_instantiate_sampler(args.sampler))
return dataloader
# trainer
elif dist == "dist":
@@ -526,18 +524,11 @@ class TorchDDPDriver(TorchDriver):
num_replicas=self.world_size,
rank=self.global_rank
)
return replace_sampler(dataloader, sampler)
batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False)
return replace_batch_sampler(dataloader, batch_sampler)
else:
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")

def backward(self, loss):
self.grad_scaler.scale(loss).backward()

def step(self):
for optimizer in self.optimizers:
self.grad_scaler.step(optimizer)
self.grad_scaler.update()

def is_global_zero(self):
return self.global_rank == 0
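
The `partial(self.model, **{_MODE_PARAMETER: ...})` pattern above simply pre-binds keyword arguments onto the wrapped model's forward; a tiny standalone illustration (hypothetical names, not fastNLP classes):

from functools import partial

def forward(batch, forward_state, wo_auto_param_call):
    return f"state={forward_state} wo={wo_auto_param_call} batch={batch}"

train_step = partial(forward, forward_state="TRAIN", wo_auto_param_call=False)
print(train_step([1, 2]))   # state=TRAIN wo=False batch=[1, 2]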



+ 33
- 32
fastNLP/core/drivers/torch_driver/dist_utils.py View File

@@ -3,28 +3,20 @@ import pickle
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
from typing import Any, List
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8


from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8
from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from torch import distributed as dist
try:
from torch._C._distributed_c10d import ProcessGroupMPI
except ImportError:
_MPI_AVAILABLE = False

try:
from torch._C._distributed_c10d import ProcessGroupNCCL
except ImportError:
_NCCL_AVAILABLE = False

try:
from torch._C._distributed_c10d import ProcessGroupGloo
from torch._C._distributed_c10d import _ProcessGroupWrapper
except ImportError:
_GLOO_AVAILABLE = False
if _TORCH_GREATER_EQUAL_1_8:
try:
from torch._C._distributed_c10d import ProcessGroupGloo
from torch._C._distributed_c10d import _ProcessGroupWrapper
except ImportError:
pass


from fastNLP.core.utils import apply_to_collection

@@ -42,7 +34,7 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list):
)


def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None):
def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP):
"""
Gather objects from the other ranks to the dst rank.

@@ -91,6 +83,9 @@ def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None)
>>> output
['foo', 12, {1: 2}]
"""
if group is None:
group = DEFAULT_TORCH_GROUP

if dist.distributed_c10d._rank_not_in_group(group):
return

@@ -193,7 +188,7 @@ def _to_device(tensor, device):
return tensor.contiguous().to(device)


def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List:
def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) ->List:
"""
All_gather for data of any type goes through this interface. Non-tensor data is transported by pickle serialization and deserialization.

@@ -217,7 +212,8 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List:
:param group:
:return: the result is [obj0, obj1, ...], where obj_i is the obj on rank i.
"""
# # first move everything to cpu and make it contiguous, to avoid pickle problems
if group is None:
group = DEFAULT_TORCH_GROUP
if isinstance(obj, torch.Tensor):
objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))]
dist.all_gather(objs, obj, group=group)
@@ -232,7 +228,7 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List:
return objs


def fastnlp_torch_broadcast_object(obj, src, device=None, group=None):
def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GROUP):
"""
Broadcast the obj on src to the other ranks.

@@ -242,6 +238,8 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=None):
:param group:
:return:
"""
if group is None:
group = DEFAULT_TORCH_GROUP
cur_rank = dist.get_rank(group)
if cur_rank == src:
# if there are tensors, move them all to cpu to ease pickling; otherwise unpickling may place them on the sender's card
@@ -335,19 +333,21 @@ def all_gather_object(object_list, obj, group=None):
>>> output
['foo', 12, {1: 2}]
"""
if dist._rank_not_in_group(group):
if dist.distributed_c10d._rank_not_in_group(group):
return
if _TORCH_GREATER_EQUAL_1_8:
current_device = torch.device("cpu")
is_nccl_backend = _check_for_nccl_backend(group)
if is_nccl_backend:
# See note about using torch.cuda.current_device() here in docstring.
# We cannot simply use my_rank since rank == device is not necessarily
# true.
current_device = torch.device("cuda", torch.cuda.current_device())
else:
current_device = torch.cuda.current_device()

input_tensor, local_size = _object_to_tensor(obj, device=current_device)

input_tensor, local_size = _object_to_tensor(obj)
current_device = torch.device("cpu")
is_nccl_backend = _check_for_nccl_backend(group)
if is_nccl_backend:
# See note about using torch.cuda.current_device() here in docstring.
# We cannot simply use my_rank since rank == device is not necessarily
# true.
current_device = torch.device("cuda", torch.cuda.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
# Gather all local sizes. This is so that we can find the max size, and index
# until the correct size when deserializing the tensors.
group_size = dist.get_world_size(group=group)
@@ -378,3 +378,4 @@ def all_gather_object(object_list, obj, group=None):
tensor = tensor.cpu()
tensor_size = object_size_list[i]
object_list[i] = _tensor_to_object(tensor, tensor_size)
return object_list
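
Underneath `all_gather_object`, non-tensor objects travel as uint8 tensors; a self-contained sketch of the assumed `_object_to_tensor` / `_tensor_to_object` round trip (single process, no process group involved):

import io
import pickle
import torch

def object_to_tensor_sketch(obj):
    buf = io.BytesIO()
    pickle.Pickler(buf).dump(obj)
    data = torch.ByteTensor(list(buf.getvalue()))
    return data, torch.tensor([data.numel()])

def tensor_to_object_sketch(tensor, size):
    raw = bytes(tensor[:size].tolist())
    return pickle.Unpickler(io.BytesIO(raw)).load()

t, n = object_to_tensor_sketch({"foo": 12})
print(tensor_to_object_sketch(t, int(n)))   # {'foo': 12}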

+ 5
- 13
fastNLP/core/drivers/torch_driver/single_device.py View File

@@ -13,7 +13,7 @@ __all__ = [
from .torch_driver import TorchDriver
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
from fastNLP.core.utils import auto_param_call
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler
from fastNLP.core.log import logger


@@ -102,29 +102,21 @@ class TorchSingleDriver(TorchDriver):

def train_step(self, batch) -> Dict:
# if batch is a Dict we do the argument matching for it by default; otherwise it is passed straight into the `train_step` function for the user to handle;
if isinstance(batch, Dict):
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
else:
return self._train_step(batch)

def backward(self, loss):
self.grad_scaler.scale(loss).backward()

def step(self):
for optimizer in self.optimizers:
self.grad_scaler.step(optimizer)
self.grad_scaler.update()

def validate_step(self, batch) -> Dict:
# the Tester's logic is to pass all metrics to the tester, which then controls each metric's update and compute; so regardless of whether the user
# implements a validate_step function, it should return a dict; which entries get used is decided by each individual metric in validate_batch_loop;
if isinstance(batch, Dict):
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
else:
return self._validate_step(batch)

def test_step(self, batch) -> Dict:
if isinstance(batch, Dict):
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
else:
return self._test_step(batch)
@@ -148,7 +140,7 @@ class TorchSingleDriver(TorchDriver):
return replace_sampler(dataloader, sampler)

if reproducible:
batch_sampler = ReproducibleBatchSampler(
batch_sampler = RandomBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last


+ 27
- 5
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler


class TorchDriver(Driver):
@@ -51,6 +51,9 @@ class TorchDriver(Driver):
# sets the `non_blocking` parameter of `torch_move_data_to_device`;
self.non_blocking = kwargs.get("torch_non_blocking", True)

# whether to disable the automatic argument matching in auto_param_call;
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)

def zero_grad(self, set_to_none: bool = False):
for optimizer in self.optimizers:
self._clear_grad(optimizer, set_to_none)
@@ -69,6 +72,14 @@ class TorchDriver(Driver):
p.grad.requires_grad_(False)
p.grad.zero_()

def backward(self, loss):
self.grad_scaler.scale(loss).backward()

def step(self):
for optimizer in self.optimizers:
self.grad_scaler.step(optimizer)
self.grad_scaler.update()

@staticmethod
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
if is_train:
@@ -102,7 +113,7 @@ class TorchDriver(Driver):
if mode == "validate":
if not hasattr(model, "validate_step"):
if hasattr(model, "test_step"):
logger.warning(
logger.warning_once(
"Your model does not have 'validate_step' method but has 'test_step' method, but you"
"are using 'mode=validate', we are going to use 'test_step' to substitute for"
"'validate_step'.")
@@ -191,9 +202,20 @@ class TorchDriver(Driver):
sampler = dataloader_args.sampler
else:
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
num_consumed_batches = states.pop('num_consumed_batches')
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
states['sampler_states'] = sampler.state_dict()
sampler_states = sampler.state_dict()
# if present, num_consumed_samples needs special handling: because the DataLoader prefetches, using the sampler's num_consumed_samples
# directly would count more than was actually consumed.
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
if num_consumed_samples_array is not None:
if isinstance(sampler, ReproducibleSampler):  # for a sampler, batch_size must be taken into account.
try:
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
except:  # batch_size may be None, in which case some precision is lost
num_consumed_batches = sampler_states['num_consumed_samples']
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
else:
raise RuntimeError(
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')
@@ -252,7 +274,7 @@ class TorchDriver(Driver):
elif self.is_distributed():
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.")
else:
sampler = ReproducibleBatchSampler(
sampler = RandomBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last


+ 15
- 15
fastNLP/core/drivers/torch_driver/utils.py View File

@@ -8,6 +8,7 @@ import numpy as np
import inspect

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.samplers import re_instantiate_sampler

if _NEED_IMPORT_TORCH:
import torch
@@ -140,24 +141,25 @@ class _DDPWrappingModel(Module):
pytorch lightning unwraps the model first; that does not seem strictly necessary for us, so we just leave this note here and revisit it later if needed;
"""

_forward_state = kwargs.pop(_MODE_PARAMETER)
forward_state = kwargs.pop(_MODE_PARAMETER)
wo_auto_param_call = kwargs.pop("wo_auto_param_call")

if _forward_state == ForwardState.TRAIN:
if isinstance(batch, Dict):
if forward_state == ForwardState.TRAIN:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
else:
return self._train_step(batch)
elif _forward_state == ForwardState.VALIDATE:
if isinstance(batch, Dict):
elif forward_state == ForwardState.VALIDATE:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
else:
return self._validate_step(batch)
elif _forward_state == ForwardState.TEST:
if isinstance(batch, Dict):
elif forward_state == ForwardState.TEST:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
else:
return self._test_step(batch)
elif _forward_state == ForwardState.PREDICT:
elif forward_state == ForwardState.PREDICT:
raise NotImplementedError("'PREDICT' mode has not been implemented.")
else:
raise NotImplementedError("You should direct a concrete mode.")
@@ -294,7 +296,6 @@ def replace_sampler(dataloader: "DataLoader", sampler):
"manually add the `DistributedSampler` as: "
f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`."
)

return type(dataloader)(**reconstruct_args)


@@ -306,12 +307,8 @@ def _dataloader_init_kwargs_resolve_sampler(
"""
batch_sampler = getattr(dataloader, "batch_sampler")
# checking the batch sampler type is different than PyTorch default.
if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
batch_sampler = type(batch_sampler)(
sampler,
batch_size=batch_sampler.batch_size,
drop_last=batch_sampler.drop_last,
)
if batch_sampler is not None and not isinstance(batch_sampler, BatchSampler):
batch_sampler = re_instantiate_sampler(batch_sampler)

return {
"sampler": None,
@@ -342,6 +339,9 @@ def replace_batch_sampler(dataloader, new_batch_sampler):
params = {k: getattr(dataloader, k) for k in params_keys}
params["batch_sampler"] = new_batch_sampler
return type(dataloader)(**params)
# TODO could we use auto_param_call here?
# return auto_param_call(type(dataloader), params, {'self': type(dataloader).__new__()},
# signature_fn=type(dataloader).__init__)


def optimizer_state_to_device(state, device):


+ 19
- 3
fastNLP/core/log/logger.py View File

@@ -51,6 +51,7 @@ class LoggerSingleton(type):
class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
def __init__(self, name):
super().__init__(name)
self._warning_msgs = set()

def add_file(self, path: Optional[Union[str, Path]] = None, level='AUTO', remove_other_handlers: bool = False,
mode: str = "w"):
@@ -108,10 +109,25 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
kwargs = self._add_rank_info(kwargs)
self._log(WARNING, msg, args, **kwargs)

def warning_once(self, msg, *args, **kwargs):
"""
A given warning message is emitted only once.

:param msg:
:param args:
:param kwargs:
:return:
"""
if msg not in self._warning_msgs:
if self.isEnabledFor(WARNING):
kwargs = self._add_rank_info(kwargs)
self._log(WARNING, msg, args, **kwargs)
self._warning_msgs.add(msg)
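# Usage sketch (hedged) for warning_once: with the singleton logger above,
# identical messages collapse to a single emission per process, e.g.
#
#   from fastNLP.core.log import logger
#   for _ in range(3):
#       logger.warning_once("`set_to_none` does nothing in paddle.")
#   # printed once; later calls are filtered by the _warning_msgs cache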

def warn(self, msg, *args, **kwargs):
warnings.warn("The 'warn' method is deprecated, "
"use 'warning' instead", DeprecationWarning, 2)
self.warning(msg, *args, **kwargs)
if self.isEnabledFor(WARNING):
kwargs = self._add_rank_info(kwargs)
self._log(WARNING, msg, args, **kwargs)

def error(self, msg, *args, **kwargs):
"""


+ 2
- 3
fastNLP/core/metrics/accuracy.py View File

@@ -14,8 +14,7 @@ from fastNLP.core.utils.utils import seq_len_to_mask

class Accuracy(Metric):

def __init__(self, backend: Union[str, Backend, None] = 'auto',
aggregate_when_get_metric: bool = True):
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True):
super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend)
self.register_element(name='total', value=0, aggregate_method="sum", backend=backend)
@@ -64,7 +63,7 @@ class Accuracy(Metric):
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")

else:
raise RuntimeError(f"when pred havesize:{pred.shape}, target should have size: {pred.shape} or "
raise RuntimeError(f"when pred have size:{pred.shape}, target should have size: {pred.shape} or "
f"{pred.shape[:-1]}, got {target.shape}.")

if masks is not None:


+ 3
- 3
fastNLP/core/samplers/__init__.py View File

@@ -23,14 +23,14 @@ __all__ = [
"BucketedBatchSampler",
"ReproducibleBatchSampler",

"re_instantiate_sampler",
"conversion_between_reproducible_and_unrepeated_sampler"
"re_instantiate_sampler"
]

from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler
from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler
from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler
from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler
from .utils import re_instantiate_sampler
from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler


+ 33
- 0
fastNLP/core/samplers/conversion_utils.py View File

@@ -0,0 +1,33 @@
from fastNLP.core.samplers import re_instantiate_sampler
from fastNLP.core.samplers.reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, \
SortedSampler
from fastNLP.core.samplers.unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, \
UnrepeatedSequentialSampler, UnrepeatedSortedSampler


def conversion_between_reproducible_and_unrepeated_sampler(sampler):
"""
Convert a sampler into its corresponding reproducible or unrepeated version. If the input is an UnrepeatedSampler but no
corresponding ReproducibleSampler is found (or vice versa), a TypeError is raised.

:param sampler:
:return:
"""
assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \
"The sampler must be UnrepeatedSampler or ReproducibleSampler"
if isinstance(sampler, UnrepeatedSampler):
if isinstance(sampler, UnrepeatedRandomSampler):
return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler)
elif isinstance(sampler, UnrepeatedSequentialSampler):
return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler)
elif isinstance(sampler, UnrepeatedSortedSampler):
return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler)
raise TypeError(f"{sampler.__class__} has no unrepeated version.")
else:
if isinstance(sampler, RandomSampler):
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler)
elif isinstance(sampler, SequentialSampler):
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler)
elif isinstance(sampler, SortedSampler):
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler)
raise TypeError(f"{sampler.__class__} has no reproducible version.")

+ 60
- 43
fastNLP/core/samplers/reproducible_batch_sampler.py View File

@@ -4,16 +4,18 @@ __all__ = [
]

import math
from array import array
from copy import deepcopy
from typing import Dict, Union, List
from itertools import chain
import os

import numpy as np

from fastNLP.core.dataset import DataSet
from fastNLP.core.log import logger
from .utils import create_array, NumConsumedSamplesArray
from abc import abstractmethod
from fastNLP.envs.env import FASTNLP_DEQUE_SIZE


class ReproducibleBatchSampler:
@@ -34,6 +36,13 @@ class ReproducibleBatchSampler:

@abstractmethod
def state_dict(self):
"""
Because today's DataLoaders all prefetch data, please refer to how num_consumed_samples_array is maintained
inside the states of RandomBatchSampler in order to set this value correctly. The idea is to record the
num_consumed_samples corresponding to each index; at Trainer.save time, the num_consumed_samples matching the
number of samples the Trainer actually forwarded is looked up in num_consumed_samples_array and stored.

:return:
"""
raise NotImplementedError("Each specific batch_sampler should implement its own `state_dict` method.")

@abstractmethod
@@ -67,7 +76,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
self.batch_size = batch_size
self.drop_last = drop_last

self.data_idx = kwargs.get("data_idx", 0)
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)

self.index_list = kwargs.get("index_list", self._iterate_sampler())
self.need_reinitialize = kwargs.get("need_reinitialize", False)
@@ -80,36 +89,40 @@ class RandomBatchSampler(ReproducibleBatchSampler):
# a sampler was passed in at initialization, which in theory corresponds to a dataloader created with neither batch_size nor batch_sampler;
else:
_index_lst.append(idx)
# an unsigned int on a 64-bit machine is 4 bytes, so the maximum representable value is 4294967295;
if len(_index_lst) > 4294967295:
# note that self.index_list stores the indices of the entire dataset;
# unsigned long
_index_lst = array("L", _index_lst)
else:
# unsigned int
_index_lst = array("I", _index_lst)
_index_lst = create_array(len(_index_lst), _index_lst)
return _index_lst

def __iter__(self):
if self.need_reinitialize:
self.index_list = self._iterate_sampler()
self.data_idx = 0
self.num_consumed_samples = 0
else:
self.need_reinitialize = True

batch = []
if self.data_idx:
index_list = self.index_list[self.data_idx:]
if self.num_consumed_samples:
index_list = self.index_list[self.num_consumed_samples:]
else:
index_list = self.index_list

# remember the consumed_samples corresponding to each batch; this is needed because today's dataloaders all
# prefetch data, so the data actually consumed can only be pinned down together with batch_idx_in_epoch in the Trainer. This variable records the true num_consumed_samples at the moment of each yield.
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=int(os.environ.get(FASTNLP_DEQUE_SIZE, 30)),
num_consumed_samples=self.num_consumed_samples)
for idx in index_list:
batch.append(idx)
self.data_idx += 1
if len(batch) == self.batch_size:
self.num_consumed_samples += self.batch_size # [16, 32, 48, 64,..., ]
self.num_consumed_samples_array.push(self.num_consumed_samples)
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
self.num_consumed_samples += len(batch)
self.num_consumed_samples_array.push(self.num_consumed_samples)
yield batch
# reset to avoid boundary-condition problems
self.num_consumed_samples = 0
delattr(self, 'num_consumed_samples_array')

def __len__(self) -> int:
if self.drop_last:
@@ -118,7 +131,13 @@ class RandomBatchSampler(ReproducibleBatchSampler):
return (len(self.index_list) + self.batch_size - 1) // self.batch_size

def state_dict(self) -> Dict:
return {"index_list": deepcopy(self.index_list), "data_idx": self.data_idx, 'sampler_type': self.__class__.__name__}
states = {
"index_list": deepcopy(self.index_list),
"num_consumed_samples": self.num_consumed_samples,
'sampler_type': self.__class__.__name__
}
states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None)
return states

def load_state_dict(self, states: Dict):
assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
@@ -128,11 +147,11 @@ class RandomBatchSampler(ReproducibleBatchSampler):
assert len(_index_list) == len(self.index_list), "The number of samples is different between the checkpoint " \
"record and current dataset."
self.index_list = _index_list
self.data_idx = states["data_idx"]
self.num_consumed_samples = states["num_consumed_samples"]
self.need_reinitialize = False

def set_distributed(self, num_replicas, rank, pad=True):
raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.")
raise RuntimeError(f"RandomBatchSampler does not support to change to distributed training.")

def set_epoch(self, epoch):
if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch):
@@ -141,10 +160,10 @@ class RandomBatchSampler(ReproducibleBatchSampler):
@property
def batch_idx_in_epoch(self):
if self.drop_last:
return len(self.index_list) // self.batch_size - (len(self.index_list) - self.data_idx) // self.batch_size
return len(self.index_list) // self.batch_size - (len(self.index_list) - self.num_consumed_samples) // self.batch_size
else:
return (len(self.index_list) + self.batch_size - 1) // self.batch_size - \
(len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size
(len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size
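
A quick worked example of the arithmetic above (illustrative numbers only):

# With len(index_list) = 100, batch_size = 16, num_consumed_samples = 48 and drop_last=False:
#   total batches      = (100 + 16 - 1) // 16 = 7
#   remaining batches  = (100 - 48 + 16 - 1) // 16 = 4
#   batch_idx_in_epoch = 7 - 4 = 3, i.e. three full batches have already been yielded.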


class BucketedBatchSampler(ReproducibleBatchSampler):
@@ -166,8 +185,8 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
:param kwargs: reserved for internal use by fastNLP
"""
super().__init__()
if isinstance(dataset, DataSet):
length = dataset.get_field(length)
if isinstance(dataset, DataSet) and isinstance(length, str):
length = dataset.get_field(length).content
if not isinstance(length[0], int):
length = list(map(len, length))
else:
@@ -180,7 +199,6 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
self.length = np.array(length, dtype=int)  # per-sample lengths
self.sorted_indices = np.argsort(self.length)[::-1]  # indices sorted by length, longest first


self.batch_size = batch_size
self.num_batch_per_bucket = num_batch_per_bucket
self.shuffle = shuffle
@@ -212,13 +230,13 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
self.rank = rank
self.pad = pad

num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \
else len(self.dataset)
if self.drop_last:
assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \
"than the number of replicates multiplied " \
"with batch_size when drop_last=True."
# num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \
# else len(self.dataset)
#
# if self.drop_last:
# assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \
# "than the number of replicates multiplied " \
# "with batch_size when drop_last=True."

return self

@@ -243,7 +261,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))

def __len__(self):
def __len__(self)->int:
"""
Return how many more batches of data this sampler will yield.

@@ -309,11 +327,15 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
batches = batches[:-1]

self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=int(os.environ.get(FASTNLP_DEQUE_SIZE, 30)),
num_consumed_samples=self.num_consumed_samples)
for batch in batches:
self.num_consumed_samples += self.num_replicas * len(batch)
self.num_consumed_samples_array.push(self.num_consumed_samples)
yield list(map(int, batch))
self.during_iter = False
self.num_consumed_samples = 0
delattr(self, 'num_consumed_samples_array')
self.old_batch_size = self.batch_size
self.old_num_batch_per_bucket = self.num_batch_per_bucket
self.old_num_replicas = self.num_replicas
@@ -356,7 +378,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
batch_indices = list(batch_indices[:-1])
rng = np.random.default_rng(abs(seed))  # prevents differing bucket lengths from affecting the random state
rng.shuffle(batch_indices)  # shuffle across batches as well; this keeps the batch lengths on every card close to each other.
batches = (np.array(batches)[batch_indices]).tolist()
batches = (np.array(batches, dtype=object)[batch_indices]).tolist()
if last_batches:
batches = batches + last_batches
return batches
@@ -365,21 +387,16 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
if self.old_batch_size != self.batch_size or self.old_num_batch_per_bucket != self.num_batch_per_bucket:
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been"
" consumed. ")
states = {
'seed': self.seed,
'epoch': self.epoch,
'num_consumed_samples': self.num_consumed_samples,  # note that this value counts the data trained on across all ranks;
'sampler_type': self.__class__.__name__,
'length': len(self.dataset),
'shuffle': self.shuffle,
'batch_size': self.batch_size,
'num_batch_per_bucket': self.num_batch_per_bucket,
'num_replicas': self.num_replicas
}
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket,
'num_replicas': self.num_replicas,
'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)}

return states

def load_state_dict(self, states: Dict):
# if self.during_iter is True, then data_idx must be 0;
# if self.during_iter is True, then num_consumed_samples must be 0;
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
"during an unfinished iteration."



+ 44
- 28
fastNLP/core/samplers/reproducible_sampler.py View File

@@ -1,16 +1,21 @@
__all__ = [
'ReproducibleSampler',
'RandomSampler',
"SortedSampler",
"SequentialSampler"
]

from typing import Dict, List, Union
import math
import os

import numpy as np

from fastNLP.core.log import logger
from fastNLP.core.dataset import DataSet
from fastNLP.envs.env import FASTNLP_DEQUE_SIZE
from .utils import NumConsumedSamplesArray

__all__ = [
'ReproducibleSampler',
'RandomSampler',
"SortedSampler",
"SequentialSampler"
]


class ReproducibleSampler:
@@ -30,6 +35,13 @@ class ReproducibleSampler:
raise NotImplementedError("Each specific sampler should implement its own `__iter__` method.")

def state_dict(self):
"""
Because today's DataLoaders all prefetch data, please refer to how num_consumed_samples_array is maintained
inside the states of RandomSampler in order to set this value correctly. The idea is to record the
num_consumed_samples corresponding to each index; at Trainer.save time, the num_consumed_samples matching the
number of samples the Trainer actually forwarded is looked up in num_consumed_samples_array and stored.

:return:
"""
raise NotImplementedError("Each specific sampler should implement its own `state_dict` method.")

def load_state_dict(self, states):
@@ -109,12 +121,15 @@ class RandomSampler(ReproducibleSampler):
indices = indices[self.num_consumed_samples:]
indices = indices[self.rank:len(indices):self.num_replicas]
assert len(indices) == self.num_left_samples

for index in indices:
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=int(os.environ.get(FASTNLP_DEQUE_SIZE, 2000)),
num_consumed_samples=self.num_consumed_samples)
for idx, index in enumerate(indices, start=1):
self.num_consumed_samples += self.num_replicas
self.num_consumed_samples_array.push(self.num_consumed_samples)
yield index
self.during_iter = False
self.num_consumed_samples = 0
delattr(self, 'num_consumed_samples_array')

def generate_indices(self) -> List[int]:
"""
@@ -134,18 +149,13 @@ class RandomSampler(ReproducibleSampler):
return indices

def state_dict(self) -> Dict:
states = {
'seed': self.seed,
'epoch': self.epoch,
'num_consumed_samples': self.num_consumed_samples,  # note that this value counts the data trained on across all ranks;
'sampler_type': self.__class__.__name__,
'length': len(self.dataset),
'shuffle': self.shuffle
}
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)}
return states

def load_state_dict(self, states: Dict):
# if self.during_iter is True, then data_idx must be 0;
# if self.during_iter is True, then num_consumed_samples must be 0;
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
"during an unfinished iteration."

@@ -158,7 +168,7 @@ class RandomSampler(ReproducibleSampler):
self.seed = states['seed']
self.epoch = states['epoch']
self.num_consumed_samples = states['num_consumed_samples']
if self.num_consumed_samples>=length: # if the last sample had already been reached when saving, simply reset the value to 0
if self.num_consumed_samples >= length: # if the last sample had already been reached when saving, simply reset the value to 0
self.num_consumed_samples = 0
if self.shuffle != states['shuffle']:
logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, "
@@ -245,11 +255,15 @@ class SequentialSampler(RandomSampler):
indices = indices[self.rank:len(indices):self.num_replicas]
assert len(indices) == self.num_left_samples

for index in indices:
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=int(os.environ.get(FASTNLP_DEQUE_SIZE, 2000)),
num_consumed_samples=self.num_consumed_samples)
for idx, index in enumerate(indices, start=1):
self.num_consumed_samples += self.num_replicas
self.num_consumed_samples_array.push(self.num_consumed_samples)
yield index
self.during_iter = False
self.num_consumed_samples = 0
delattr(self, 'num_consumed_samples_array')

def generate_indices(self) -> List[int]:
"""
@@ -260,15 +274,13 @@ class SequentialSampler(RandomSampler):
return list(range(len(self.dataset)))

def state_dict(self) -> Dict:
states = {
'num_consumed_samples': self.num_consumed_samples,  # note that this value counts the data trained on across all ranks;
'sampler_type': self.__class__.__name__,
'length': len(self.dataset),
}
states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__,
'length': len(self.dataset),
'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)}
return states

def load_state_dict(self, states: Dict):
# if self.during_iter is True, then data_idx must be 0;
# if self.during_iter is True, then num_consumed_samples must be 0;
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
"during an unfinished iteration."

@@ -295,8 +307,8 @@ class SortedSampler(SequentialSampler):
:param kwargs: reserved for internal use by fastNLP
"""
super().__init__(dataset=dataset, **kwargs)
if isinstance(dataset, DataSet):
length = dataset.get_field(length)
if isinstance(dataset, DataSet) and isinstance(length, str):
length = dataset.get_field(length).content
if not isinstance(length[0], int):
length = list(map(len, length))
else:
@@ -334,9 +346,13 @@ class SortedSampler(SequentialSampler):
indices = indices[self.rank:len(indices):self.num_replicas]
assert len(indices) == self.num_left_samples

for index in indices:
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=int(os.environ.get(FASTNLP_DEQUE_SIZE, 2000)),
num_consumed_samples=self.num_consumed_samples)
for idx, index in enumerate(indices, start=1):
self.num_consumed_samples += self.num_replicas
self.num_consumed_samples_array.push(self.num_consumed_samples)
yield index
self.during_iter = False
self.num_consumed_samples = 0
delattr(self, 'num_consumed_samples_array')


+ 2
- 2
fastNLP/core/samplers/unrepeated_sampler.py View File

@@ -105,8 +105,8 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler):
:param kwargs: reserved for internal use by fastNLP
"""
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs)
if isinstance(dataset, DataSet):
length = dataset.get_field(length)
if isinstance(dataset, DataSet) and isinstance(length, str):
length = dataset.get_field(length).content
if not isinstance(length[0], int):
length = list(map(len, length))
else:


+ 53
- 30
fastNLP/core/samplers/utils.py View File

@@ -1,42 +1,65 @@
__all__ = [
're_instantiate_sampler',
'conversion_between_reproducible_and_unrepeated_sampler'
're_instantiate_sampler'
]
from array import array
from typing import Sequence
from collections import deque

from fastNLP.core.samplers.unrepeated_sampler import *
from fastNLP.core.samplers.reproducible_sampler import *

def re_instantiate_sampler(sampler, new_sampler_class=None):
all_attributes = vars(sampler)
if new_sampler_class is not None:
return new_sampler_class(**all_attributes)
return type(sampler)(**all_attributes)

def conversion_between_reproducible_and_unrepeated_sampler(sampler):

def create_array(length, fill_value) -> array:
"""
Replace the sampler with its corresponding reproducible or unrepeated version. If the input is an
UnrepeatedSampler with no matching ReproducibleSampler,
Automatically create an array from the given length; above 4294967295 the typecode 'L' is required, otherwise 'I' is used

:param sampler:
:param length:
:param fill_value:
:return:
"""
assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \
"The sampler must be UnrepeatedSampler or ReproducibleSampler"
if isinstance(sampler, UnrepeatedSampler):
if isinstance(sampler, UnrepeatedRandomSampler):
return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler)
elif isinstance(sampler, UnrepeatedSequentialSampler):
return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler)
elif isinstance(sampler, UnrepeatedSortedSampler):
return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler)
raise TypeError(f"{sampler.__class__} has no unrepeated version.")
if not isinstance(fill_value, Sequence):
fill_value = [fill_value]*length

if length > 4294967295:
_index_lst = array("L", fill_value)
else:
if isinstance(sampler, RandomSampler):
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler)
elif isinstance(sampler, SequentialSampler):
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler)
elif isinstance(sampler, SortedSampler):
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler)
raise TypeError(f"{sampler.__class__} has no reproducible version.")
_index_lst = array("I", fill_value)
return _index_lst
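
For illustration (not part of the commit), the typecode selection behaves as follows:

# Hypothetical example of create_array.
idx = create_array(5, [0, 1, 2, 3, 4])
assert idx.typecode == 'I' and list(idx) == [0, 1, 2, 3, 4]
broadcast = create_array(3, 7)   # a non-sequence fill_value is repeated length times
assert list(broadcast) == [7, 7, 7]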


def re_instantiate_sampler(sampler, new_sampler_class=None):
all_attributes = vars(sampler)
if new_sampler_class is not None:
return new_sampler_class(**all_attributes)
return type(sampler)(**all_attributes)
class NumConsumedSamplesArray:
def __init__(self, buffer_size=2000, num_consumed_samples=0):
"""
Keep the most recent buffer_size num_consumed_samples values; indexing retrieves the
num_consumed_samples recorded for a given index.
ex:
array = NumConsumedSamplesArray(buffer_size=3, num_consumed_samples=None)
for i in range(10):
array.push(i)

array[9] # outputs 9, the true num_consumed_samples at this position.
array[6] # raises, since only the 3 most recent entries are kept and 6 falls outside the buffer, i.e. [7, 8, 9]

:param buffer_size: how many history entries to keep.
:param num_consumed_samples: the initial num_consumed_samples value.
"""
self.count = 0
self.deque = deque(maxlen=buffer_size)
if num_consumed_samples is not None:
self.push(num_consumed_samples)
self.buffer_size = buffer_size

def __getitem__(self, item):
if len(self.deque) == 0:  # nothing cached yet means nothing has been written; simply return 0
return 0
assert isinstance(item, int), "Only int index allowed."
assert self.count-len(self.deque)<=item<self.count, f"Only keep {len(self.deque)} history index."
index = len(self.deque) - (self.count - item)
return self.deque[index]

def push(self, num_consumed_samples):
self.deque.append(num_consumed_samples)
self.count += 1
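
A small sketch of the intended use (values are illustrative; buffer_size would normally come from FASTNLP_DEQUE_SIZE):

# Only the last buffer_size entries are kept; older indices fail the assert in __getitem__.
arr = NumConsumedSamplesArray(buffer_size=3, num_consumed_samples=0)
for step in range(1, 10):
    arr.push(step * 16)   # e.g. 16 samples consumed per yielded batch
print(arr[9])             # 144 -- the newest of the three retained entries
# arr[5] would raise an AssertionError: that history has been evicted.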

+ 1
- 2
fastNLP/core/utils/__init__.py View File

@@ -13,7 +13,6 @@ __all__ = [
'torch_paddle_move_data_to_device',
'torch_move_data_to_device',
'get_fn_arg_names',
'check_fn_not_empty_params',
'auto_param_call',
'check_user_specific_params',
'dataclass_to_dict',
@@ -36,7 +35,7 @@ from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_devi
from .rich_progress import f_rich_progress
from .torch_paddle_utils import torch_paddle_move_data_to_device
from .torch_utils import torch_move_data_to_device
from .utils import get_fn_arg_names, check_fn_not_empty_params, auto_param_call, check_user_specific_params, \
from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \
indice_collate_wrapper, deprecated, seq_len_to_mask, synchronize_safe_rm, synchronize_mkdir



+ 4
- 1
fastNLP/core/utils/paddle_utils.py View File

@@ -46,11 +46,14 @@ def get_paddle_device_id(device: Union[str, int]):
device = device.lower()
if device == "cpu":
raise ValueError("Cannot get device id from `cpu`.")
elif device == "gpu":
return 0

match_res = re.match(r"gpu:\d+", device)
if not match_res:
raise ValueError(
"The device must be a string which is like 'cpu', 'gpu', 'gpu:x'"
"The device must be a string which is like 'cpu', 'gpu', 'gpu:x', "
f"not '{device}'"
)
device_id = device.split(':', 1)[1]
device_id = int(device_id)


+ 77
- 2
fastNLP/core/utils/rich_progress.py View File

@@ -6,7 +6,7 @@
import sys
from typing import Any, Union, Optional

from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live
from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live, Text, ProgressSample
from rich.progress import ProgressColumn, TimeRemainingColumn, BarColumn, TimeElapsedColumn, TextColumn

__all__ = [
@@ -146,24 +146,99 @@ class FRichProgress(Progress, metaclass=Singleton):
if task_id in self._tasks:
super().stop_task(task_id)
super().remove_task(task_id)
self.refresh()  # refresh so that no leftover bar lingers

def start(self) -> None:
super().start()
self.console.show_cursor(show=True)

def update(
self,
task_id: TaskID,
*,
total: Optional[float] = None,
completed: Optional[float] = None,
advance: Optional[float] = None,
description: Optional[str] = None,
visible: Optional[bool] = None,
refresh: bool = False,
**fields: Any,
) -> None:
"""Update information associated with a task.

Args:
task_id (TaskID): Task id (returned by add_task).
total (float, optional): Updates task.total if not None.
completed (float, optional): Updates task.completed if not None.
advance (float, optional): Add a value to task.completed if not None.
description (str, optional): Change task description if not None.
visible (bool, optional): Set visible flag if not None.
refresh (bool): Force a refresh of progress information. Default is False.
**fields (Any): Additional data fields required for rendering.
"""
with self._lock:
task = self._tasks[task_id]
completed_start = task.completed

if total is not None and total != task.total:
task.total = total
task._reset()
if advance is not None:
task.completed += advance
if completed is not None:
task.completed = completed
if description is not None:
task.description = description
if visible is not None:
task.visible = visible
task.fields.update(fields)
update_completed = task.completed - completed_start

current_time = self.get_time()
old_sample_time = current_time - self.speed_estimate_period
_progress = task._progress

popleft = _progress.popleft
# changed to keep at least one sample, so that extremely long iterations do not skew the estimate
while len(_progress)>1 and _progress[0].timestamp < old_sample_time:
popleft()
if update_completed > 0:
_progress.append(ProgressSample(current_time, update_completed))
if task.completed >= task.total and task.finished_time is None:
task.finished_time = task.elapsed

if refresh:
self.refresh()


class SpeedColumn(ProgressColumn):
"""
Display the speed of a task.

"""
def render(self, task: "Task"):
speed = task.speed
if speed is None:
return Text('-- it./s', style='progress.data.speed')
if speed > 0.1:
return Text(str(round(speed, 2))+' it./s', style='progress.data.speed')
else:
return Text(str(round(1/speed, 2))+' s/it.', style='progress.data.speed')
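
As a hedged sketch, SpeedColumn can be dropped into any rich progress bar the same way the fastNLP bar below wires it in:

# Illustrative only: renders '-- it./s', 'N it./s' for fast tasks,
# or 'N s/it.' once speed drops to 0.1 iterations per second or below.
from rich.progress import Progress, BarColumn

with Progress(BarColumn(), SpeedColumn()) as progress:
    task = progress.add_task("demo", total=100)
    for _ in range(100):
        progress.update(task, advance=1)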


if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0:
f_rich_progress = FRichProgress().new_progess(
"[progress.description]{task.description}",
"[progress.percentage]{task.percentage:>3.0f}%",
BarColumn(),
SpeedColumn(),
TimeElapsedColumn(),
"/",
TimeRemainingColumn(),
TextColumn("{task.fields[post_desc]}", justify="right"),
transient=True,
disable=False,
speed_estimate_period=1
speed_estimate_period=30
)
else:
f_rich_progress = DummyFRichProgress()


+ 4
- 2
fastNLP/core/utils/torch_utils.py View File

@@ -1,9 +1,11 @@
from abc import ABC
from typing import Any, Union, Optional
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_8
DEFAULT_TORCH_GROUP = None
if _NEED_IMPORT_TORCH:
import torch
if not _TORCH_GREATER_EQUAL_1_8:
DEFAULT_TORCH_GROUP = torch.distributed.distributed_c10d.group.WORLD

__all__ = [
'torch_move_data_to_device'


+ 117
- 37
fastNLP/core/utils/utils.py View File

@@ -1,3 +1,4 @@
import functools
import inspect
from inspect import Parameter
import dataclasses
@@ -24,10 +25,8 @@ from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_GLOBAL_RANK



__all__ = [
'get_fn_arg_names',
'check_fn_not_empty_params',
'auto_param_call',
'check_user_specific_params',
'dataclass_to_dict',
@@ -44,48 +43,23 @@ __all__ = [
]





def get_fn_arg_names(fn: Callable) -> List[str]:
r"""
Return the names of all of a function's parameters;

:param fn: the function to inspect;

:return: a list whose elements are the string names of the function's parameters;
"""
return list(inspect.signature(fn).parameters)


def check_fn_not_empty_params(fn: Optional[Callable] = None, param_num: Optional[int] = None) -> bool:
r"""
Check whether the given batch_step_fn is legal: (1) is it callable; (2) does it have exactly the specified number of parameters without default values;
the user may also pass in a partial function, as long as it keeps parameter slots for `trainer` and `batch`;

:param fn: the function passed in to replace the 'step' function of the Loop;
:param param_num: the expected number of parameters without default values;

:return: bool indicating whether the given `batch_step_fn` is valid;
"""

if fn is None:
return True
if not callable(fn):
return False
else:
params = inspect.signature(fn).parameters
not_default_params = {}
for _name, _param in params.items():
if _param.default == Parameter.empty:
not_default_params[_name] = _param
return len(not_default_params) == param_num


def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None,
mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any:
r"""
1. This function lets the user achieve automatic computation through string matching;
it looks up values matching the target function's parameter names in *args (hence all of them must be dicts) and calls the function with them; if the passed-in data does not match fn's parameters, the mapping
parameter can translate the keys. A (key, value) pair in mapping means: find the value stored under key in *args and pass it to the parameter named value.

1. This function lets the user achieve automatic invocation through string matching;
2. Note that mapping defaults to None; if you want to specify how the inputs map onto the target function's parameters, pass mapping in as such a dict;
if mapping is not None, we always rewrite the keys of the input dicts with mapping first, so please verify the correctness of mapping yourself;
3. If a parameter of the target function has a default value, that default is used when no subsequent input supplies a value for it; otherwise the supplied value is used;
@@ -113,6 +87,7 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20})) # res: 140
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200})) # res: 240
"""

if signature_fn is not None:
if not callable(signature_fn):
raise ValueError(f"Parameter `signature_fn` should be `Callable`.")
@@ -122,7 +97,8 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
_kwargs = None
for _name, _param in _need_params.items():
if _param.kind == Parameter.VAR_POSITIONAL:
raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn.__name__}.")
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn_msg}.")
if _param.kind == Parameter.VAR_KEYWORD:
_kwargs = (_name, _param)

@@ -135,12 +111,17 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
_default_params[_name] = _param.default

if mapping is not None:
assert isinstance(mapping, Dict), f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}."
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
assert isinstance(mapping, Dict), f"Exception happens when calling {fn_msg}. " \
f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}."

_has_params = {}
duplicate_names = []
for arg in args:
assert isinstance(arg, Dict), "The input part of function `auto_param_call` can only be `Dict` type."
if not isinstance(arg, Dict):
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
raise TypeError(f"Exception happens when calling {fn_msg}. "
f"The input part of function `auto_param_call` must be `Dict` type, instead of {type(arg)}.")
for _name, _value in arg.items():
if mapping is not None and _name in mapping:
_name = mapping[_name]
@@ -152,7 +133,8 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
elif _name in _need_params and not (_has_params[_name] is _value):
duplicate_names.append(_name)
if duplicate_names:
raise ValueError(f"The following key present in several inputs:{duplicate_names}")
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
raise ValueError(f"The following key present in several inputs:{duplicate_names} when calling {fn_msg}.")

# pass in parameters that have default values and were not overridden by the inputs;
for _name, _value in _default_params.items():
@@ -161,11 +143,89 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None

if len(_has_params)<len(_need_params):
miss_params = list(set(_need_params.keys()) - set(_has_params.keys()))
raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn.__name__} are not found in the input.")
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
_provided_keys = _get_keys(args)
raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn_msg} "
f"are not found in the input keys({_provided_keys}).")

return fn(**_has_params)


def _get_keys(args:List[Dict]) -> List[List[str]]:
"""
Return the keys of every dict.

:param args:
:return:
"""
_provided_keys = []
for arg in args:
_provided_keys.append(list(arg.keys()))
return _provided_keys


def _get_fun_msg(fn)->str:
"""
Get basic information about a function, to help with error reporting.
ex:
print(_get_fun_msg(_get_fun_msg))
# `_get_fun_msg(fn) -> str`(In file:/Users/hnyan/Desktop/projects/fastNLP/fastNLP/fastNLP/core/utils/utils.py)

:param callable fn:
:return:
"""
if isinstance(fn, functools.partial):
return _get_fun_msg(fn.func)
try:
fn_name = fn.__qualname__ + str(inspect.signature(fn))
except:
fn_name = str(fn)
try:
fp = '(In file:' + os.path.abspath(inspect.getfile(fn)) + ')'
except:
fp = ''
msg = f'`{fn_name}`' + fp
return msg


def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None):
"""
Check whether a function can accept the expected_params arguments (i.e. whether the counts match), after excluding self (for methods), parameters with
default values, etc. If they cannot be matched, an error is raised.

:param fn: the function to check; may be a method or a plain function.
:param expected_params: the parameters the function is expected to support.
:param fn_name: the name of fn, for clearer errors when the given fn is not callable.
:return:
"""
if fn_name is not None:
assert callable(fn), f"{fn_name} should be callable, instead of {type(fn)}."

parameters = list(inspect.signature(fn).parameters.values())
if inspect.ismethod(fn):
if len(parameters)>0 and parameters[0].name == 'self':
parameters = parameters[1:] # drop self

no_var_param = True # no *args/**kwargs style parameter
number_param_need_value = 0
for param in parameters:
if param.kind is param.VAR_POSITIONAL:
no_var_param = False
elif param.kind is param.VAR_KEYWORD:
no_var_param = False
else:
if param.default is param.empty:
number_param_need_value += 1

if len(parameters)<len(expected_params) and no_var_param:
raise RuntimeError(f"The function:{_get_fun_msg(fn)} accepts {len(parameters)} parameters, "
f"but {len(expected_params)} parameters:{expected_params} will be provided.")

if number_param_need_value>len(expected_params):
raise RuntimeError(f"The function:{_get_fun_msg(fn)} expects {len(parameters)} parameters, but only"
f" {len(expected_params)} parameters:{expected_params} will be provided.")


def check_user_specific_params(user_params: Dict, fn: Callable):
"""
This function uses the user's input to assign values to the parameters of the given function;
@@ -184,7 +244,7 @@ def check_user_specific_params(user_params: Dict, fn: Callable):
return user_params


def dataclass_to_dict(data: "dataclass") -> Dict:
def dataclass_to_dict(data: "dataclasses.dataclass") -> Dict:
if not is_dataclass(data):
raise TypeError(f"Parameter `data` can only be `dataclass` type instead of {type(data)}.")
_dict = dict()
@@ -591,4 +651,24 @@ def synchronize_mkdir(path: Optional[Union[str, Path]]):
wait_to_success(path.exists)


def get_class_that_defined_method(method):
"""
Given a method, return the class object on which that method is defined.

:param method:
:return:
"""
if isinstance(method, functools.partial):
return get_class_that_defined_method(method.func)
if inspect.ismethod(method) or (inspect.isbuiltin(method) and getattr(method, '__self__', None) is not None and getattr(method.__self__, '__class__', None)):
for cls in inspect.getmro(method.__self__.__class__):
if method.__name__ in cls.__dict__:
return cls
method = getattr(method, '__func__', method) # fallback to __qualname__ parsing
if inspect.isfunction(method):
cls = getattr(inspect.getmodule(method),
method.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0],
None)
if isinstance(cls, type):
return cls
return getattr(method, '__objclass__', None) # handle special descriptor objects
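
For illustration, a hedged sketch of what the helper resolves (a class defined at module scope is assumed for the __qualname__ path):

import functools

class A:
    def f(self):
        pass

assert get_class_that_defined_method(A().f) is A                      # bound-method branch
assert get_class_that_defined_method(A.f) is A                        # plain function, via __qualname__
assert get_class_that_defined_method(functools.partial(A().f)) is A   # partials are unwrapped first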

+ 3
- 2
fastNLP/envs/__init__.py View File

@@ -6,7 +6,8 @@ __all__ = [
'is_cur_env_distributed',
'get_global_rank',
'rank_zero_call',
'all_rank_call'
'all_rank_call',
'get_gpu_count'
]


@@ -14,5 +15,5 @@ from .env import *
from .set_env_on_import import set_env_on_import
from .set_backend import dump_fastnlp_backend
from .imports import *
from .utils import _module_available
from .utils import _module_available, get_gpu_count
from .distributed import *

+ 2
- 0
fastNLP/envs/env.py View File

@@ -45,6 +45,8 @@ FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK'
# todo: add documentation
FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH"

# default size of deques initialized inside fastNLP
FASTNLP_DEQUE_SIZE = 'FASTNLP_DEQUE_SIZE'

# todo: add documentation; variables used directly
FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar"


+ 11
- 8
fastNLP/envs/set_backend.py View File

@@ -5,13 +5,13 @@
import os
import json
import sys
import subprocess
from collections import defaultdict


from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED
from fastNLP.envs.imports import SUPPORT_BACKENDS
from fastNLP.envs.utils import _module_available

from fastNLP.envs.utils import _module_available, get_gpu_count

def _set_backend():
"""
@@ -56,17 +56,18 @@ def _set_backend():
if 'PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ:
# in a distributed child process, derive the devices this process actually owns from USER_VISIBLE_DEVICES
selected_gpus = os.environ['FLAGS_selected_gpus'].split(',')
if user_visible_devices is not None and user_visible_devices != "":
if user_visible_devices is not None:
# the user launched distributed training through CUDA_VISIBLE_DEVICES
# after set_backend, the user's setting is stored in USER_CUDA_VISIBLE_DEVICES
# from which we need to look up the device ids actually in use
user_visible_devices = user_visible_devices.split(",")
selected_gpus = ",".join([user_visible_devices[int(i)] for i in selected_gpus])
else:
# set USER_CUDA_VISIBLE_DEVICES to indicate that, from the user's perspective, all devices are visible
os.environ[USER_CUDA_VISIBLE_DEVICES] = ""
# TODO the [0] here may be problematic with multiple cards on a single node
os.environ['CUDA_VISIBLE_DEVICES'] = selected_gpus[0]
# USER_CUDA_VISIBLE_DEVICES was not found, so set it to all devices
os.environ[USER_CUDA_VISIBLE_DEVICES] = ",".join(map(str, list(
range(get_gpu_count())
)))
os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(selected_gpus)
os.environ['FLAGS_selected_gpus'] = ",".join([str(g) for g in range(len(selected_gpus))])
os.environ['FLAGS_selected_accelerators'] = ",".join([str(g) for g in range(len(selected_gpus))])
elif 'CUDA_VISIBLE_DEVICES' in os.environ:
@@ -78,7 +79,9 @@ def _set_backend():
else:
# if nothing is set, restrict to a single card to avoid occupying other cards in multi-process runs
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ[USER_CUDA_VISIBLE_DEVICES] = ''
os.environ[USER_CUDA_VISIBLE_DEVICES] = ",".join(map(str, list(
range(get_gpu_count())
)))

elif backend == 'jittor':
assert _module_available(backend), f"You must have {backend} available to use {backend} backend."


+ 1
- 2
fastNLP/envs/set_env_on_import.py View File

@@ -36,8 +36,7 @@ def set_env_on_import_torch():

# TODO paddle may need set this
def set_env_on_import_paddle():
# todo: FASTNLP_GLOBAL_RANK and FASTNLP_LAUNCH_PROCESS need to be set here
if "PADDLE_TRANERS_NUM" in os.environ and "PADDLE_TRAINER_ID" in os.environ \
if "PADDLE_TRAINERS_NUM" in os.environ and "PADDLE_TRAINER_ID" in os.environ \
and "PADDLE_RANK_IN_NODE" in os.environ:
# environment variables of a distributed setup were detected
os.environ[FASTNLP_GLOBAL_RANK] = os.environ["PADDLE_TRAINER_ID"]


+ 13
- 0
fastNLP/envs/utils.py View File

@@ -3,6 +3,7 @@ from typing import Callable
import importlib
from pkg_resources import DistributionNotFound
from packaging.version import Version
import subprocess
import pkg_resources


@@ -46,3 +47,15 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
if use_base_version:
pkg_version = Version(pkg_version.base_version)
return op(pkg_version, Version(version))

def get_gpu_count():
"""
Get the number of GPUs via a command-line query.
:return: the number of GPUs, or -1 if there is no GPU device
"""
try:
lines = subprocess.check_output(['nvidia-smi', '--query-gpu=memory.used', '--format=csv'])
# after splitting, the header line and the trailing newline also have to be discarded
return len(lines.split(b"\n")) - 2
except:
return -1
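
The count assumes nvidia-smi prints one CSV row per GPU plus a header; roughly (illustrative output):

# memory.used [MiB]
# 1024 MiB
# 0 MiB
# check_output returns these rows with a trailing newline, so splitting on
# b"\n" gives 4 parts and 4 - 2 == 2 GPUs.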

+ 5
- 5
fastNLP/io/data_bundle.py View File

@@ -251,10 +251,10 @@ class DataBundle:
def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True,
ignore_miss_dataset=True, progress_desc: str = '', show_progress_bar: bool = True):
r"""
Apply the :meth:`~fastNLP.DataSet.apply_field_more` method to every dataset in the :class:`~fastNLP.io.DataBundle`
Apply the :method:`~fastNLP.DataSet.apply_field_more` method to every dataset in the :class:`~fastNLP.io.DataBundle`

.. note::
For the difference between ``apply_field_more`` and ``apply_field``, see the introduction to ``apply_more`` versus ``apply`` in :meth:`fastNLP.DataSet.apply_more`.
For the difference between ``apply_field_more`` and ``apply_field``, see the introduction to ``apply_more`` versus ``apply`` in :method:`fastNLP.DataSet.apply_more`.

:param callable func: takes an ``Instance`` of the ``DataSet`` as argument and returns a dict whose keys are field names and whose values are the corresponding results
@@ -285,7 +285,7 @@ class DataBundle:
def apply(self, func: Callable, new_field_name: str, num_proc: int = 0,
progress_desc: str = '', show_progress_bar: bool = True, _apply_field: str = None):
r"""
Apply the :meth:`~fastNLP.DataSet.apply` method to every dataset in the :class:`~fastNLP.io.DataBundle`
Apply the :method:`~fastNLP.DataSet.apply` method to every dataset in the :class:`~fastNLP.io.DataBundle`

Applies the apply method to all datasets in the DataBundle

@@ -309,10 +309,10 @@ class DataBundle:
def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0,
progress_desc: str = '', show_progress_bar: bool = True):
r"""
Apply the :meth:`~fastNLP.DataSet.apply_more` method to every dataset in the :class:`~fastNLP.io.DataBundle`
Apply the :method:`~fastNLP.DataSet.apply_more` method to every dataset in the :class:`~fastNLP.io.DataBundle`

.. note::
For the difference between ``apply_more`` and ``apply``, see the introduction to ``apply_more`` versus ``apply`` in :meth:`fastNLP.DataSet.apply_more`.
For the difference between ``apply_more`` and ``apply``, see the introduction to ``apply_more`` versus ``apply`` in :method:`fastNLP.DataSet.apply_more`.

:param callable func: takes an ``Instance`` of the ``DataSet`` as argument and returns a dict whose keys are field names and whose values are the corresponding results


+ 1
- 1
fastNLP/io/pipe/classification.py View File

@@ -87,7 +87,7 @@ class CLSBasePipe(Pipe):

def process_from_file(self, paths) -> DataBundle:
r"""
Given file path(s), produce a processed DataBundle object. For the path formats supported by paths, see ::meth:`fastNLP.io.Loader.load()`
Given file path(s), produce a processed DataBundle object. For the path formats supported by paths, see ::method:`fastNLP.io.Loader.load()`

:param paths:
:return: DataBundle


+ 1
- 1
fastNLP/io/pipe/construct_graph.py View File

@@ -164,7 +164,7 @@ class GraphBuilderBase:

def build_graph_from_file(self, path: str):
r"""
Given a file path, produce a processed scipy_sparse_matrix object. For the path formats supported by paths, see ::meth:`fastNLP.io.Loader.load()`
Given a file path, produce a processed scipy_sparse_matrix object. For the path formats supported by paths, see ::method:`fastNLP.io.Loader.load()`

:param path:
:return: scipy_sparse_matrix


+ 1
- 1
fastNLP/io/pipe/pipe.py View File

@@ -33,7 +33,7 @@ class Pipe:

def process_from_file(self, paths: str) -> DataBundle:
r"""
Given file path(s), produce a processed DataBundle object. For the path formats supported by paths, see ::meth:`fastNLP.io.Loader.load()`
Given file path(s), produce a processed DataBundle object. For the path formats supported by paths, see ::method:`fastNLP.io.Loader.load()`

:param str paths:
:return: DataBundle


+ 1
- 1
tests/core/callbacks/test_callback_events.py View File

@@ -1,7 +1,7 @@
import pytest
from functools import reduce

from fastNLP.core.callbacks.callback_events import Filter
from fastNLP.core.callbacks.callback_events import Events, Filter


class TestFilter:


+ 2
- 2
tests/core/callbacks/test_checkpoint_callback_torch.py View File

@@ -10,7 +10,7 @@ import re

from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback
from fastNLP.core.controllers.trainer import Trainer
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK

from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import synchronize_safe_rm
@@ -238,7 +238,7 @@ def test_model_checkpoint_callback_2(

from fastNLP.core.callbacks.callback_events import Events

@Trainer.on(Events.ON_TRAIN_EPOCH_END)
@Trainer.on(Events.on_train_epoch_end)
def raise_exception(trainer):
if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4:
raise NotImplementedError


+ 93
- 0
tests/core/controllers/test_trainer_fleet.py View File

@@ -0,0 +1,93 @@
"""
This file tests the case where the user launches training with python -m paddle.distributed.launch,
to see whether there is any chance of running it through pytest:
python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py
"""
import os
os.environ["FASTNLP_BACKEND"] = "paddle"
import sys
sys.path.append("../../../")

from dataclasses import dataclass

from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.callbacks.progress_callback import RichCallback
from fastNLP.core.callbacks import Callback

import paddle
from paddle.optimizer import Adam
from paddle.io import DataLoader

from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback

@dataclass
class MNISTTrainFleetConfig:
num_labels: int = 10
feature_dimension: int = 10

batch_size: int = 32
shuffle: bool = True
validate_every = -1

def test_trainer_fleet(
driver,
device,
callbacks,
n_epochs,
):
model = PaddleNormalModel_Classification_1(
num_labels=MNISTTrainFleetConfig.num_labels,
feature_dimension=MNISTTrainFleetConfig.feature_dimension
)
optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)

train_dataloader = DataLoader(
dataset=PaddleRandomMaxDataset(6400, MNISTTrainFleetConfig.feature_dimension),
batch_size=MNISTTrainFleetConfig.batch_size,
shuffle=True
)
val_dataloader = DataLoader(
dataset=PaddleRandomMaxDataset(1280, MNISTTrainFleetConfig.feature_dimension),
batch_size=MNISTTrainFleetConfig.batch_size,
shuffle=True
)
train_dataloader = train_dataloader
validate_dataloaders = val_dataloader
validate_every = MNISTTrainFleetConfig.validate_every
metrics = {"acc": Accuracy()}
trainer = Trainer(
model=model,
driver=driver,
device=device,
optimizers=optimizers,
train_dataloader=train_dataloader,
validate_dataloaders=validate_dataloaders,
validate_every=validate_every,
input_mapping=None,
output_mapping=None,
metrics=metrics,

n_epochs=n_epochs,
callbacks=callbacks,
output_from_new_proc="logs",
)
trainer.run()

if __name__ == "__main__":
driver = "fleet"
device = [0,2,3]
# driver = "paddle"
# device = 2
callbacks = [
# RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),
RichCallback(5),
]
test_trainer_fleet(
driver=driver,
device=device,
callbacks=callbacks,
n_epochs=5,
)

+ 98
- 0
tests/core/controllers/test_trainer_fleet_outside.py View File

@@ -0,0 +1,98 @@
"""
This file tests the case where the user launches training with python -m paddle.distributed.launch
and initializes fleet by themselves:
python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py
"""
import os
os.environ["FASTNLP_BACKEND"] = "paddle"
import sys
sys.path.append("../../../")

from dataclasses import dataclass

from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.callbacks.progress_callback import RichCallback
from fastNLP.core.callbacks import Callback

import paddle
from paddle.optimizer import Adam
from paddle.io import DataLoader
import paddle.distributed.fleet as fleet

from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_2
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback

@dataclass
class MNISTTrainFleetConfig:
num_labels: int = 10
feature_dimension: int = 10

batch_size: int = 32
shuffle: bool = True
validate_every = -1

def test_trainer_fleet(
driver,
device,
callbacks,
n_epochs,
):
fleet.init(is_collective=True)

model = PaddleNormalModel_Classification_2(
num_labels=MNISTTrainFleetConfig.num_labels,
feature_dimension=MNISTTrainFleetConfig.feature_dimension,
)
optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)

model = fleet.distributed_model(model)
optimizers = fleet.distributed_optimizer(optimizers)

train_dataloader = DataLoader(
dataset=PaddleRandomMaxDataset(6400, MNISTTrainFleetConfig.feature_dimension),
batch_size=MNISTTrainFleetConfig.batch_size,
shuffle=True
)
val_dataloader = DataLoader(
dataset=PaddleRandomMaxDataset(1280, MNISTTrainFleetConfig.feature_dimension),
batch_size=MNISTTrainFleetConfig.batch_size,
shuffle=True
)
train_dataloader = train_dataloader
validate_dataloaders = val_dataloader
validate_every = MNISTTrainFleetConfig.validate_every
metrics = {"acc": Accuracy()}
trainer = Trainer(
model=model,
driver=driver,
device=device,
optimizers=optimizers,
train_dataloader=train_dataloader,
validate_dataloaders=validate_dataloaders,
validate_every=validate_every,
input_mapping=None,
output_mapping=None,
metrics=metrics,

n_epochs=n_epochs,
callbacks=callbacks,
output_from_new_proc="logs",
data_device=f"gpu:{os.environ['CUDA_VISIBLE_DEVICES']}"
)
trainer.run()

if __name__ == "__main__":
driver = "fleet"
device = [0,2,3]
callbacks = [
# RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),
RichCallback(5),
]
test_trainer_fleet(
driver=driver,
device=device,
callbacks=callbacks,
n_epochs=30,
)

+ 25
- 0
tests/core/controllers/test_trainer_other_things.py View File

@@ -0,0 +1,25 @@
import pytest

from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.callbacks import Events
from tests.helpers.utils import magic_argv_env_context


@magic_argv_env_context
def test_trainer_torch_without_evaluator():
@Trainer.on(Events.ON_TRAIN_EPOCH_BEGIN(every=10))
def fn1(trainer):
pass

@Trainer.on(Events.ON_TRAIN_BATCH_BEGIN(every=10))
def fn2(trainer, batch, indices):
pass

with pytest.raises(AssertionError):
@Trainer.on(Events.ON_TRAIN_BATCH_BEGIN(every=10))
def fn3(trainer, batch):
pass





+ 43
- 3
tests/core/controllers/test_trainer_w_evaluator_torch.py View File

@@ -98,14 +98,16 @@ def model_and_optimizers(request):


# test the ordinary case;
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) #, ("torch", 1), ("torch", [0, 1])
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]])
@pytest.mark.parametrize("validate_every", [-3])
@magic_argv_env_context
def test_trainer_torch_with_evaluator(
model_and_optimizers: TrainerParameters,
driver,
device,
callbacks,
validate_every,
n_epochs=10,
):
trainer = Trainer(
@@ -118,11 +120,11 @@ def test_trainer_torch_with_evaluator(
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
validate_every=validate_every,

n_epochs=n_epochs,
callbacks=callbacks,
output_from_new_proc="all"

)

trainer.run()
@@ -143,7 +145,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps(
accumulation_steps,
n_epochs=6,
):
callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)]
callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.1, larger_better=True)]
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
@@ -169,4 +171,42 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps(
dist.destroy_process_group()


@pytest.mark.parametrize("driver,device", [("torch", 1)]) # ("torch", [0, 1]),("torch", 1)
@magic_argv_env_context
def test_trainer_validate_every(
model_and_optimizers: TrainerParameters,
driver,
device,
n_epochs=6,
):

def validate_every(trainer):
if trainer.global_forward_batches % 10 == 0:
print(trainer)
print("\nfastNLP test validate every.\n")
print(trainer.global_forward_batches)
return True

trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
validate_dataloaders=model_and_optimizers.validate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,

n_epochs=n_epochs,
output_from_new_proc="all",
validate_every=validate_every
)

trainer.run()

if dist.is_initialized():
dist.destroy_process_group()




+ 40
- 3
tests/core/controllers/test_trainer_wo_evaluator_torch.py View File

@@ -10,7 +10,7 @@ from typing import Any
from pathlib import Path

from fastNLP.core.controllers.trainer import Trainer
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1, TorchNormalModel_Classification_3
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch
@@ -70,7 +70,7 @@ def model_and_optimizers(request):
trainer_params.output_mapping = None

# elif request.param == 1:
# model =

return trainer_params

@@ -254,7 +254,7 @@ def test_trainer_on_exception(
):
from fastNLP.core.callbacks.callback_events import Events

@Trainer.on(Events.ON_TRAIN_EPOCH_END)
@Trainer.on(Events.on_train_epoch_end)
def raise_exception(trainer):
if trainer.driver.get_local_rank() == cur_rank:
raise NotImplementedError
@@ -307,10 +307,47 @@ def test_torch_distributed_launch_2(version):
subprocess.check_call(command)


@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch_ddp", [0, 1])])
@magic_argv_env_context
def test_torch_wo_auto_param_call(
driver,
device,
n_epochs=10,
):

model = TorchNormalModel_Classification_3(
num_labels=NormalClassificationTrainTorchConfig.num_labels,
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension
)
optimizers = SGD(model.parameters(), lr=0.001)
dataset = TorchNormalDataset_Classification(
num_labels=NormalClassificationTrainTorchConfig.num_labels,
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension,
each_label_data=NormalClassificationTrainTorchConfig.each_label_data,
seed=NormalClassificationTrainTorchConfig.seed
)
train_dataloader = DataLoader(
dataset=dataset,
batch_size=NormalClassificationTrainTorchConfig.batch_size,
shuffle=True
)

trainer = Trainer(
model=model,
driver=driver,
device=device,
optimizers=optimizers,
train_dataloader=train_dataloader,
n_epochs=n_epochs,

model_wo_auto_param_call=True,
output_from_new_proc="all"
)

trainer.run()

if dist.is_initialized():
dist.destroy_process_group()





+ 58
- 38
tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py View File

@@ -1,83 +1,103 @@
import os
import pytest

from fastNLP.envs.set_backend import set_env
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle

set_env_on_import_paddle()
set_env("paddle")
import paddle
os.environ["FASTNLP_BACKEND"] = "paddle"

from fastNLP.core.drivers import PaddleSingleDriver, PaddleFleetDriver
from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification
from fastNLP.envs import get_gpu_count
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.utils import magic_argv_env_context

import paddle

def test_incorrect_driver():

model = PaddleNormalModel_Classification_1(2, 100)
with pytest.raises(ValueError):
driver = initialize_paddle_driver("torch")
driver = initialize_paddle_driver("torch", 0, model)

@pytest.mark.parametrize(
"device",
["cpu", "gpu:0", [1, 2, 3], 0, "gpu:1"]
["cpu", "gpu:0", 0, [1]]
)
def test_get_single_device(device):
@pytest.mark.parametrize(
"driver",
["paddle"]
)
def test_get_single_device(driver, device):
"""
Test initializing PaddleSingleDriver under normal circumstances
"""

model = PaddleNormalModel_Classification(2, 100)
driver = initialize_paddle_driver("paddle", device, model)

model = PaddleNormalModel_Classification_1(2, 100)
driver = initialize_paddle_driver(driver, device, model)
assert isinstance(driver, PaddleSingleDriver)

@pytest.mark.parametrize(
"device",
["cpu", "gpu:0", [1, 2, 3], 0, "gpu:1"]
[0, 1]
)
def test_get_single_device_with_visiblde_devices(device):
@pytest.mark.parametrize(
"driver",
["fleet"]
)
@magic_argv_env_context
def test_get_fleet_2(driver, device):
"""
Test initializing PaddleSingleDriver when launched with CUDA_VISIBLE_DEVICES
Test initializing fleet with multiple cards
"""
# TODO

model = PaddleNormalModel_Classification(2, 100)
driver = initialize_paddle_driver("paddle", device, model)
model = PaddleNormalModel_Classification_1(64, 10)
driver = initialize_paddle_driver(driver, device, model)

assert isinstance(driver, PaddleSingleDriver)
assert isinstance(driver, PaddleFleetDriver)

@pytest.mark.parametrize(
"device",
[[1, 2, 3]]
[[0, 2, 3], -1]
)
@pytest.mark.parametrize(
"driver",
["paddle", "fleet"]
)
def test_get_fleet(device):
@magic_argv_env_context
def test_get_fleet(driver, device):
"""
Test initializing fleet with multiple cards
"""

model = PaddleNormalModel_Classification(2, 100)
driver = initialize_paddle_driver("paddle", device, model)
model = PaddleNormalModel_Classification_1(64, 10)
driver = initialize_paddle_driver(driver, device, model)

assert isinstance(driver, PaddleFleetDriver)

@pytest.mark.parametrize(
"device",
[[1,2,3]]
("driver", "device"),
[("fleet", "cpu")]
)
def test_get_fleet(device):
@magic_argv_env_context
def test_get_fleet_cpu(driver, device):
"""
Test initializing fleet with multiple cards when started via launch
Test attempting to initialize distributed training on cpu
"""
# TODO

model = PaddleNormalModel_Classification(2, 100)
driver = initialize_paddle_driver("paddle", device, model)

assert isinstance(driver, PaddleFleetDriver)
model = PaddleNormalModel_Classification_1(64, 10)
with pytest.raises(ValueError):
driver = initialize_paddle_driver(driver, device, model)

def test_device_out_of_range(device):
@pytest.mark.parametrize(
"device",
[-2, [0, get_gpu_count() + 1, 3], [-2], get_gpu_count() + 1]
)
@pytest.mark.parametrize(
"driver",
["paddle", "fleet"]
)
@magic_argv_env_context
def test_device_out_of_range(driver, device):
"""
Test the case where the given device is out of range
"""
pass
model = PaddleNormalModel_Classification_1(2, 100)
with pytest.raises(ValueError):
driver = initialize_paddle_driver(driver, device, model)

+ 0
- 262
tests/core/drivers/paddle_driver/test_paddle_driver.py View File

@@ -1,262 +0,0 @@
import unittest

import torch

from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver
import paddle
from paddle.io import Dataset, DataLoader

class Net(paddle.nn.Layer):
def __init__(self):
super(Net, self).__init__()

self.fc1 = paddle.nn.Linear(784, 64)
self.fc2 = paddle.nn.Linear(64, 32)
self.fc3 = paddle.nn.Linear(32, 10)
self.fc4 = paddle.nn.Linear(10, 10)

def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
x = self.fc4(x)

return x


class PaddleDataset(Dataset):
def __init__(self):
super(PaddleDataset, self).__init__()
self.items = [paddle.rand((3, 4)) for i in range(320)]

def __len__(self):
return len(self.items)

def __getitem__(self, idx):
return self.items[idx]


class TorchNet(torch.nn.Module):
def __init__(self):
super(TorchNet, self).__init__()

self.torch_fc1 = torch.nn.Linear(10, 10)
self.torch_softmax = torch.nn.Softmax(0)
self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3)
self.torch_tensor = torch.ones(3, 3)
self.torch_param = torch.nn.Parameter(torch.ones(4, 4))


class TorchDataset(torch.utils.data.Dataset):
def __init__(self):
super(TorchDataset, self).__init__()
self.items = [torch.ones(3, 4) for i in range(320)]

def __len__(self):
return len(self.items)

def __getitem__(self, idx):
return self.items[idx]


class PaddleDriverTestCase(unittest.TestCase):
"""
Test class for PaddleDriver; due to the nature of the class only some functions are tested here, the rest are covered by the PaddleSingleDriver and PaddleFleetDriver tests
"""

def setUp(self):
model = Net()
self.driver = PaddleDriver(model)

def test_check_single_optimizer_legacy(self):
"""
Tests the behavior when a single optimizer is passed in
"""
optimizer = paddle.optimizer.Adam(
parameters=self.driver.model.parameters(),
learning_rate=0.01
)

self.driver.set_optimizers(optimizer)

optimizer = torch.optim.Adam(TorchNet().parameters(), 0.01)
# Passing a torch optimizer should raise a ValueError
with self.assertRaises(ValueError) as cm:
self.driver.set_optimizers(optimizer)

def test_check_optimizers_legacy(self):
"""
Tests the behavior when a list of optimizers is passed in
"""
optimizers = [
paddle.optimizer.Adam(
parameters=self.driver.model.parameters(),
learning_rate=0.01
) for i in range(10)
]

self.driver.set_optimizers(optimizers)

optimizers += [
torch.optim.Adam(TorchNet().parameters(), 0.01)
]

with self.assertRaises(ValueError) as cm:
self.driver.set_optimizers(optimizers)

def test_check_dataloader_legacy_in_train(self):
"""
Tests the behavior of _check_dataloader_legality when is_train is True
"""
dataloader = paddle.io.DataLoader(PaddleDataset())
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True)

# Create a torch dataloader
dataloader = torch.utils.data.DataLoader(
TorchDataset(),
batch_size=32, shuffle=True
)
with self.assertRaises(ValueError) as cm:
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True)

def test_check_dataloader_legacy_in_test(self):
"""
Tests the behavior of _check_dataloader_legality when is_train is False
"""
# In this case the input should be a dict
dataloader = {"train": paddle.io.DataLoader(PaddleDataset()), "test":paddle.io.DataLoader(PaddleDataset())}
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False)

# A non-dict input should raise an error
dataloader = paddle.io.DataLoader(PaddleDataset())
with self.assertRaises(ValueError) as cm:
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False)

# Create torch dataloaders
train_loader = torch.utils.data.DataLoader(
TorchDataset(),
batch_size=32, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
TorchDataset(),
batch_size=32, shuffle=True
)
dataloader = {"train": train_loader, "test": test_loader}
with self.assertRaises(ValueError) as cm:
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False)

def test_tensor_to_numeric(self):
"""
Tests the tensor_to_numeric function
"""
# A single tensor
tensor = paddle.to_tensor(3)
res = PaddleDriver.tensor_to_numeric(tensor)
self.assertEqual(res, 3)

tensor = paddle.rand((3, 4))
res = PaddleDriver.tensor_to_numeric(tensor)
self.assertListEqual(res, tensor.tolist())

# A list of tensors
tensor_list = [paddle.rand((6, 4, 2)) for i in range(10)]
res = PaddleDriver.tensor_to_numeric(tensor_list)
self.assertTrue(res, list)
tensor_list = [t.tolist() for t in tensor_list]
self.assertListEqual(res, tensor_list)

# A tuple of tensors
tensor_tuple = tuple([paddle.rand((6, 4, 2)) for i in range(10)])
res = PaddleDriver.tensor_to_numeric(tensor_tuple)
self.assertTrue(res, tuple)
tensor_tuple = tuple([t.tolist() for t in tensor_tuple])
self.assertTupleEqual(res, tensor_tuple)

# A dict of tensors
tensor_dict = {
"tensor": paddle.rand((3, 4)),
"list": [paddle.rand((6, 4, 2)) for i in range(10)],
"dict":{
"list": [paddle.rand((6, 4, 2)) for i in range(10)],
"tensor": paddle.rand((3, 4))
},
"int": 2,
"string": "test string"
}

res = PaddleDriver.tensor_to_numeric(tensor_dict)
self.assertIsInstance(res, dict)
self.assertListEqual(res["tensor"], tensor_dict["tensor"].tolist())
self.assertIsInstance(res["list"], list)
for r, d in zip(res["list"], tensor_dict["list"]):
self.assertListEqual(r, d.tolist())
self.assertIsInstance(res["int"], int)
self.assertIsInstance(res["string"], str)
self.assertIsInstance(res["dict"], dict)
self.assertIsInstance(res["dict"]["list"], list)
for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]):
self.assertListEqual(r, d.tolist())
self.assertListEqual(res["dict"]["tensor"], tensor_dict["dict"]["tensor"].tolist())

def test_set_model_mode(self):
"""
Tests the set_model_mode function
"""
self.driver.set_model_mode("train")
self.assertTrue(self.driver.model.training)
self.driver.set_model_mode("eval")
self.assertFalse(self.driver.model.training)
# Should raise an error
with self.assertRaises(AssertionError) as cm:
self.driver.set_model_mode("test")

def test_move_model_to_device_cpu(self):
"""
Tests the move_model_to_device function
"""
PaddleDriver.move_model_to_device(self.driver.model, "cpu")
self.assertTrue(self.driver.model.fc1.weight.place.is_cpu_place())

def test_move_model_to_device_gpu(self):
"""
Tests the move_model_to_device function
"""
PaddleDriver.move_model_to_device(self.driver.model, "gpu:0")
self.assertTrue(self.driver.model.fc1.weight.place.is_gpu_place())
self.assertEqual(self.driver.model.fc1.weight.place.gpu_device_id(), 0)

def test_worker_init_function(self):
"""
Tests worker_init_function
"""
# For now, just make sure it runs
# TODO: verify correctness
PaddleDriver.worker_init_function(0)

def test_set_deterministic_dataloader(self):
"""
Tests set_deterministic_dataloader
"""
# For now, just make sure it runs
# TODO: verify correctness
dataloader = DataLoader(PaddleDataset())
self.driver.set_deterministic_dataloader(dataloader)

def test_set_sampler_epoch(self):
"""
Tests set_sampler_epoch
"""
# For now, just make sure it runs
# TODO: verify correctness
dataloader = DataLoader(PaddleDataset())
self.driver.set_sampler_epoch(dataloader, 0)

def test_get_dataloader_args(self):
"""
Tests get_dataloader_args
"""
# For now, just make sure it runs
# TODO: verify correctness
dataloader = DataLoader(PaddleDataset())
res = PaddleDriver.get_dataloader_args(dataloader)

+ 526
- 58
tests/core/drivers/paddle_driver/test_single_device.py View File

@@ -1,19 +1,19 @@
import os
os.environ["FASTNLP_BACKEND"] = "paddle"
import pytest
from pathlib import Path

from fastNLP.envs.set_backend import set_env
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
from tests.helpers.datasets.torch_data import TorchNormalDataset
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from fastNLP.core import synchronize_safe_rm

set_env_on_import_paddle()
set_env("paddle")
import paddle
from paddle.io import DataLoader, BatchSampler

from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
from fastNLP.core.samplers.reproducible_sampler import RandomSampler
from fastNLP.core.samplers import RandomBatchSampler
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST, PaddleRandomDataset
from fastNLP.core import synchronize_safe_rm
import torch


############################################################################
@@ -26,38 +26,116 @@ def generate_random_driver(features, labels):
"""
Generate a driver
"""
model = PaddleNormalModel_Classification(labels, features)
model = PaddleNormalModel_Classification_1(labels, features)
opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01)
driver = PaddleSingleDriver(model)
driver = PaddleSingleDriver(model, device="cpu")
driver.set_optimizers(opt)
driver.setup()

return driver

@pytest.fixture
def prepare_test_save_load():
dataset = PaddleRandomDataset(num_of_data=320, features=64, labels=8)
dataset = PaddleRandomMaxDataset(320, 10)
dataloader = DataLoader(dataset, batch_size=32)
driver1, driver2 = generate_random_driver(64, 8), generate_random_driver(64, 8)
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
return driver1, driver2, dataloader

def test_save_and_load(prepare_test_save_load):
@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_with_randombatchsampler(only_state_dict):
"""
Tests the save and load functions
TODO the optimizer's state_dict is empty, skip testing it for now
Tests the save and load functions, focusing on the case where the dataloader's batch_sampler has been replaced
"""

try:
path = "model.pdparams"
driver1, driver2, dataloader = prepare_test_save_load
path = "model.ckp"

driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
dataset = PaddleRandomMaxDataset(80, 10)
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
)

# TODO iterate a few batches here once checkpoint resume is complete

sampler_states = dataloader.batch_sampler.state_dict()
if only_state_dict:
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True)
else:
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)

# 1. Check the optimizer's state
# TODO the optimizer's state_dict is always empty

# 2. Check that the batch_sampler was correctly loaded and replaced
replaced_loader = states["dataloader"]
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
assert replaced_loader.batch_sampler.data_idx == sampler_states["data_idx"]

# 3. Check that the model parameters were correctly loaded
for batch in dataloader:
res1 = driver1.validate_step(batch)
res2 = driver2.validate_step(batch)

driver1.save(path, {})
driver2.load(path)
assert paddle.equal_all(res1["pred"], res2["pred"])

# 4. Check batch_idx
# TODO
finally:
synchronize_safe_rm(path)

@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_with_randomsampler(only_state_dict):
"""
Tests the save and load functions, focusing on the case where the dataloader's sampler has been replaced
"""

try:
path = "model.ckp"

driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
dataset = PaddleRandomMaxDataset(80, 10)
batch_sampler = BatchSampler(dataset=dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(dataset, True)
dataloader = DataLoader(
dataset,
batch_sampler=batch_sampler
)

# TODO iterate a few batches here once checkpoint resume is complete

sampler_states = dataloader.batch_sampler.sampler.state_dict()
if only_state_dict:
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True)
else:
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)

# 1. Check the optimizer's state
# TODO the optimizer's state_dict is always empty

# 2. Check that the sampler was correctly loaded and replaced
replaced_loader = states["dataloader"]

assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"]
assert replaced_loader.batch_sampler.sampler.num_consumed_samples == sampler_states["num_consumed_samples"]
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"]
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]

# 3. Check that the model parameters were correctly loaded
for batch in dataloader:
res1 = driver1.validate_step(batch)
res2 = driver2.validate_step(batch)

assert paddle.equal_all(res1["pred"], res2["pred"])

# 4. Check batch_idx
# TODO
finally:
synchronize_safe_rm(path)

@@ -67,13 +145,14 @@ def test_save_and_load_state_dict(prepare_test_save_load):
TODO the optimizer's state_dict is empty, skip testing it for now
"""
try:
path = "model.pdparams"
path = "dict"
driver1, driver2, dataloader = prepare_test_save_load

driver1.save_model(path)
driver2.model.load_dict(driver2.load_model(path))
driver2.load_model(path)

for batch in dataloader:
batch = driver1.move_data_to_device(batch)
res1 = driver1.validate_step(batch)
res2 = driver2.validate_step(batch)

@@ -87,19 +166,22 @@ def test_save_and_load_whole_model(prepare_test_save_load):
TODO the optimizer's state_dict is empty, skip testing it for now
"""
try:
path = "model.pdparams"
path = "model"
driver1, driver2, dataloader = prepare_test_save_load

driver1.save_model(path, only_state_dict=False, input_spec=[next(iter(dataloader))["x"]])
driver2.model = driver2.load_model(path, load_dict=False)
driver1.save_model(path, only_state_dict=False, input_spec=[paddle.ones((32, 10))])
driver2.load_model(path, only_state_dict=False)

for batch in dataloader:
batch = driver1.move_data_to_device(batch)
res1 = driver1.validate_step(batch)
res2 = driver2.validate_step(batch)

assert paddle.equal_all(res1["pred"], res2["pred"])
finally:
synchronize_safe_rm(path)
synchronize_safe_rm(path + ".pdiparams")
synchronize_safe_rm(path + ".pdiparams.info")
synchronize_safe_rm(path + ".pdmodel")


class TestSingleDeviceFunction:
@@ -109,8 +191,8 @@ class TestSingleDeviceFunction:

@classmethod
def setup_class(cls):
model = PaddleNormalModel_Classification(10, 784)
cls.driver = PaddleSingleDriver(model)
model = PaddleNormalModel_Classification_1(10, 784)
cls.driver = PaddleSingleDriver(model, device="cpu")

def test_unwrap_model(self):
"""
@@ -125,22 +207,6 @@ class TestSingleDeviceFunction:
self.driver.check_evaluator_mode("validate")
self.driver.check_evaluator_mode("test")

def test_get_model_device_cpu(self):
"""
测试get_model_device
"""
self.driver = PaddleSingleDriver(PaddleNormalModel_Classification(10, 784), "cpu")
device = self.driver.get_model_device()
assert device == "cpu", device

def test_get_model_device_gpu(self):
"""
测试get_model_device
"""
self.driver = PaddleSingleDriver(PaddleNormalModel_Classification(10, 784), "gpu:0")
device = self.driver.get_model_device()
assert device == "gpu:0", device

def test_is_distributed(self):
assert self.driver.is_distributed() == False

@@ -151,18 +217,420 @@ class TestSingleDeviceFunction:
"""
self.driver.move_data_to_device(paddle.rand((32, 64)))

@pytest.mark.parametrize(
"dist_sampler",
["dist", RandomBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))]
)
@pytest.mark.parametrize(
"reproducible",
[True, False]
)
def test_repalce_sampler(self, dist_sampler, reproducible):

class TestSetDistReproDataloder:
"""
Class dedicated to testing the set_dist_repro_dataloader function
"""
def setup_method(self):
self.dataset = PaddleNormalDataset(20)
self.dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
model = PaddleNormalModel_Classification_1(10, 32)
self.driver = PaddleSingleDriver(model, device="cpu")
def test_set_dist_repro_dataloader_with_reproducible_false(self):
"""
Tests the behavior of set_dist_repro_dataloader when `reproducible` is False
When dist is a string, the original dataloader should be returned
"""
Tests the replace_sampler function
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=False)

assert replaced_loader is self.dataloader

def test_set_dist_repro_dataloader_with_reproducible_true(self):
"""
Tests the behavior of set_dist_repro_dataloader when `reproducible` is True
When dist is a string, a new dataloader should be returned whose batch_sampler is a RandomBatchSampler
"""
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=True)

assert not (replaced_loader is self.dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == self.dataloader.drop_last

# self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader)

def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
"""
dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True)
Tests the behavior of set_dist_repro_dataloader when dist is not a string and is a ReproducibleBatchSampler
A new dataloader should be returned, with its batch_sampler replaced by the sampler given as dist
"""
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False)

assert not (replaced_loader is self.dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert replaced_loader.batch_sampler is dist

# self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader)

def test_set_dist_repro_dataloader_with_dist_sampler(self):
"""
Tests the behavior of set_dist_repro_dataloader when dist is not a string
A new dataloader should be returned, with its batch_sampler.sampler replaced by the sampler given as dist
"""
dist = RandomSampler(self.dataset, shuffle=True)
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False)

assert not (replaced_loader is self.dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler is self.dataloader.batch_sampler)
assert replaced_loader.batch_sampler.sampler is dist
assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size

# self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader)

def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self):
"""
Tests the behavior of set_dist_repro_dataloader when the dataloader already supports checkpoint resume
A new dataloader should be returned, with all other settings identical to the original
"""
dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

assert not (replaced_loader is dataloader)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last

res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible)
# self.check_set_dist_repro_dataloader(dataloader, replaced_loader)

def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self):
"""
Tests the behavior of set_dist_repro_dataloader when the dataloader already supports checkpoint resume
A new dataloader should be returned, with all other settings identical to the original
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

assert not (replaced_loader is dataloader)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2

# self.check_set_dist_repro_dataloader(dataloader, replaced_loader)

def check_set_dist_repro_dataloader(self, dataloader, replaced_loader):
"""
Checks that the result of set_dist_repro_dataloader is correct on a single device
"""
# Iterate two batches
# Here the BatchSampler may yield several times while the dataloader fetches only once.
already_seen_idx = set()
for idx, batch in replaced_loader:
already_seen_idx.update(batch)
if idx >= 1:
break
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
sampler_states = replaced_loader.batch_sampler.state_dict()
else:
sampler_states = replaced_loader.batch_sampler.sampler.state_dict()
print(sampler_states["data_idx"])

# After reloading, the remaining items should be produced; for PaddleNormalDataset the sorted result should be a range
left_idxes = set()
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
replaced_loader.batch_sampler.load_state_dict(sampler_states)
else:
replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states)
for idx, batch in enumerate(replaced_loader):
left_idxes.update(batch)

assert len(left_idxes) + len(already_seen_idx) == len(self.dataset)
assert len(left_idxes | already_seen_idx) == len(self.dataset)
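The assertions in this class imply a fixed dispatch order for set_dist_repro_dataloader. A hedged sketch of the single-device behavior they describe (the branch order and the re_instantiate_sampler helper are assumptions, not the actual method body):

def set_dist_repro_dataloader_sketch(self, dataloader, dist, reproducible):
    if isinstance(dist, ReproducibleBatchSampler):
        # a concrete batch sampler was passed in: swap it in directly
        return replace_batch_sampler(dataloader, dist)
    if isinstance(dist, ReproducibleSampler):
        # a concrete sampler was passed in: swap batch_sampler.sampler
        return replace_sampler(dataloader, dist)
    # dist is the string "dist": rebuild an already-reproducible setup fresh
    args = self.get_dataloader_args(dataloader)
    if isinstance(args.batch_sampler, ReproducibleBatchSampler):
        return replace_batch_sampler(dataloader, re_instantiate_sampler(args.batch_sampler))
    if isinstance(args.sampler, ReproducibleSampler):
        return replace_sampler(dataloader, re_instantiate_sampler(args.sampler))
    if reproducible:
        # wrap the plain BatchSampler so the run becomes resumable
        return replace_batch_sampler(
            dataloader,
            RandomBatchSampler(args.batch_sampler, args.batch_size, args.drop_last))
    return dataloader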

class TestPaddleDriverFunctions:
"""
Tests the base-class functions through PaddleSingleDriver
"""

@classmethod
def setup_class(self):
model = PaddleNormalModel_Classification_1(10, 32)
self.driver = PaddleSingleDriver(model, device="cpu")

def test_check_single_optimizer_legality(self):
"""
Tests the behavior when a single optimizer is passed in
"""
optimizer = paddle.optimizer.Adam(
parameters=self.driver.model.parameters(),
learning_rate=0.01
)

self.driver.set_optimizers(optimizer)

optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
# Passing a torch optimizer should raise a ValueError
with pytest.raises(ValueError):
self.driver.set_optimizers(optimizer)

def test_check_optimizers_legality(self):
"""
Tests the behavior when a list of optimizers is passed in
"""
optimizers = [
paddle.optimizer.Adam(
parameters=self.driver.model.parameters(),
learning_rate=0.01
) for i in range(10)
]

self.driver.set_optimizers(optimizers)

optimizers += [
torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
]

with pytest.raises(ValueError):
self.driver.set_optimizers(optimizers)

def test_check_dataloader_legality_in_train(self):
"""
Tests the behavior of _check_dataloader_legality when is_train is True
"""
dataloader = paddle.io.DataLoader(PaddleNormalDataset())
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True)

# The case where both batch_size and batch_sampler are None
dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None)
with pytest.raises(ValueError):
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True)

# Create a torch dataloader
dataloader = torch.utils.data.DataLoader(
TorchNormalDataset(),
batch_size=32, shuffle=True
)
with pytest.raises(ValueError):
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True)

def test_check_dataloader_legality_in_test(self):
"""
Tests the behavior of _check_dataloader_legality when is_train is False
"""
# In this case the input should be a dict
dataloader = {
"train": paddle.io.DataLoader(PaddleNormalDataset()),
"test":paddle.io.DataLoader(PaddleNormalDataset())
}
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False)

# The case where both batch_size and batch_sampler are None
dataloader = {
"train": paddle.io.DataLoader(PaddleNormalDataset()),
"test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None)
}
with pytest.raises(ValueError):
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False)

# A non-dict input should raise an error
dataloader = paddle.io.DataLoader(PaddleNormalDataset())
with pytest.raises(ValueError):
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False)

# Create torch dataloaders
train_loader = torch.utils.data.DataLoader(
TorchNormalDataset(),
batch_size=32, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
TorchNormalDataset(),
batch_size=32, shuffle=True
)
dataloader = {"train": train_loader, "test": test_loader}
with pytest.raises(ValueError):
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False)

def test_tensor_to_numeric(self):
"""
Tests the tensor_to_numeric function
"""
# A single tensor
tensor = paddle.to_tensor(3)
res = PaddleSingleDriver.tensor_to_numeric(tensor)
assert res == 3

tensor = paddle.rand((3, 4))
res = PaddleSingleDriver.tensor_to_numeric(tensor)
assert res == tensor.tolist()

# A list of tensors
tensor_list = [paddle.rand((6, 4, 2)) for i in range(10)]
res = PaddleSingleDriver.tensor_to_numeric(tensor_list)
assert isinstance(res, list)
tensor_list = [t.tolist() for t in tensor_list]
assert res == tensor_list

# A tuple of tensors
tensor_tuple = tuple([paddle.rand((6, 4, 2)) for i in range(10)])
res = PaddleSingleDriver.tensor_to_numeric(tensor_tuple)
assert isinstance(res, tuple)
tensor_tuple = tuple([t.tolist() for t in tensor_tuple])
assert res == tensor_tuple

# A dict of tensors
tensor_dict = {
"tensor": paddle.rand((3, 4)),
"list": [paddle.rand((6, 4, 2)) for i in range(10)],
"dict":{
"list": [paddle.rand((6, 4, 2)) for i in range(10)],
"tensor": paddle.rand((3, 4))
},
"int": 2,
"string": "test string"
}

res = PaddleSingleDriver.tensor_to_numeric(tensor_dict)
assert isinstance(res, dict)
assert res["tensor"] == tensor_dict["tensor"].tolist()
assert isinstance(res["list"], list)
for r, d in zip(res["list"], tensor_dict["list"]):
assert r == d.tolist()
assert isinstance(res["int"], int)
assert isinstance(res["string"], str)
assert isinstance(res["dict"], dict)
assert isinstance(res["dict"]["list"], list)
for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]):
assert r == d.tolist()
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist()

def test_set_model_mode(self):
"""
Tests the set_model_mode function
"""
self.driver.set_model_mode("train")
assert self.driver.model.training
self.driver.set_model_mode("eval")
assert not self.driver.model.training
# Should raise an error
with pytest.raises(AssertionError):
self.driver.set_model_mode("test")

def test_move_model_to_device_cpu(self):
"""
Tests the move_model_to_device function
"""
PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu")
assert self.driver.model.linear1.weight.place.is_cpu_place()

def test_move_model_to_device_gpu(self):
"""
Tests the move_model_to_device function
"""
PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu")
assert self.driver.model.linear1.weight.place.is_gpu_place()
assert self.driver.model.linear1.weight.place.gpu_device_id() == 0

def test_worker_init_function(self):
"""
Tests worker_init_function
"""
# For now, just make sure it runs
# TODO: verify correctness
PaddleSingleDriver.worker_init_function(0)

def test_set_deterministic_dataloader(self):
"""
Tests set_deterministic_dataloader
"""
# For now, just make sure it runs
# TODO: verify correctness
dataloader = DataLoader(PaddleNormalDataset())
self.driver.set_deterministic_dataloader(dataloader)

def test_set_sampler_epoch(self):
"""
Tests set_sampler_epoch
"""
# For now, just make sure it runs
# TODO: verify correctness
dataloader = DataLoader(PaddleNormalDataset())
self.driver.set_sampler_epoch(dataloader, 0)

@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_get_dataloader_args(self, batch_size, shuffle, drop_last):
"""
Tests the behavior of get_dataloader_args under normal conditions
"""
dataloader = DataLoader(
PaddleNormalDataset(),
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
)
res = PaddleSingleDriver.get_dataloader_args(dataloader)

assert isinstance(res.dataset, PaddleNormalDataset)
assert isinstance(res.batch_sampler, BatchSampler)
if shuffle:
assert isinstance(res.sampler, paddle.io.RandomSampler)
else:
assert isinstance(res.sampler, paddle.io.SequenceSampler)
assert res.shuffle == shuffle
assert res.batch_size == batch_size
assert res.drop_last == drop_last

@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_get_dataloader_args_with_randombatchsampler(self, batch_size, shuffle, drop_last):
"""
Tests the behavior of get_dataloader_args after the batch_sampler has been replaced
"""
dataset = PaddleNormalDataset()
dataloader = DataLoader(
dataset,
batch_sampler=RandomBatchSampler(
BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle),
batch_size,
drop_last,
)
)
res = PaddleSingleDriver.get_dataloader_args(dataloader)

assert isinstance(res.dataset, PaddleNormalDataset)
assert isinstance(res.batch_sampler, RandomBatchSampler)
if shuffle:
assert isinstance(res.sampler, paddle.io.RandomSampler)
else:
assert isinstance(res.sampler, paddle.io.SequenceSampler)
assert res.shuffle == shuffle
assert res.batch_size == batch_size
assert res.drop_last == drop_last

@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last):
"""
Tests the behavior of get_dataloader_args after the sampler has been replaced
"""
dataset = PaddleNormalDataset()
batch_sampler = BatchSampler(dataset, batch_size=batch_size, drop_last=drop_last)
batch_sampler.sampler = RandomSampler(dataset, shuffle)
dataloader = DataLoader(
dataset,
batch_sampler=batch_sampler,
)
res = PaddleSingleDriver.get_dataloader_args(dataloader)

assert isinstance(res.dataset, PaddleNormalDataset)
assert isinstance(res.batch_sampler, BatchSampler)
assert isinstance(res.sampler, RandomSampler)
assert res.shuffle == shuffle
assert res.batch_size == batch_size
assert res.drop_last == drop_last

+ 54
- 2
tests/core/drivers/paddle_driver/test_utils.py View File

@@ -1,4 +1,56 @@
import unittest
import os
import pytest
os.environ["FASTNLP_BACKEND"] = "paddle"

from fastNLP.core.drivers.paddle_driver.utils import (
get_device_from_visible,
replace_batch_sampler,
replace_sampler,
)
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler

import paddle
from paddle.io import Dataset, DataLoader, DistributedBatchSampler
from paddle.io import DataLoader, BatchSampler

from tests.helpers.datasets.paddle_data import PaddleNormalDataset

@pytest.mark.parametrize(
("user_visible_devices, cuda_visible_devices, device, output_type, correct"),
(
("0,1,2,3,4,5,6,7", "0", "cpu", str, "cpu"),
("0,1,2,3,4,5,6,7", "0", "cpu", int, "cpu"),
("0,1,2,3,4,5,6,7", "3,4,5", "gpu:4", int, 1),
("0,1,2,3,4,5,6,7", "3,4,5", "gpu:5", str, "gpu:2"),
("3,4,5,6", "3,5", 0, int, 0),
("3,6,7,8", "6,7,8", "gpu:2", str, "gpu:1"),
)
)
def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices, device, output_type, correct):
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
os.environ["USER_CUDA_VISIBLE_DEVICES"] = user_visible_devices
res = get_device_from_visible(device, output_type)
assert res == correct
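The parametrized cases encode a remapping rule: the incoming index is interpreted against USER_CUDA_VISIBLE_DEVICES and then translated into the process-local ordering defined by CUDA_VISIBLE_DEVICES. A minimal sketch inferred from the expected values above (not the actual implementation):

import os

def get_device_from_visible_sketch(device, output_type=int):
    if device == "cpu":  # cpu passes through unchanged
        return device
    user = os.environ["USER_CUDA_VISIBLE_DEVICES"].split(",")
    cuda = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
    idx = int(str(device).replace("gpu:", ""))  # index into the user's device list
    physical = user[idx]                        # physical GPU id, e.g. "4"
    local = cuda.index(physical)                # position in this process's view
    return local if output_type is int else f"gpu:{local}"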

def test_replace_batch_sampler():
dataset = PaddleNormalDataset(10)
dataloader = DataLoader(dataset, batch_size=32)
batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)

replaced_loader = replace_batch_sampler(dataloader, batch_sampler)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.dataset, PaddleNormalDataset)
assert len(replaced_loader.dataset) == len(dataset)
assert replaced_loader.batch_sampler.batch_size == 16

def test_replace_sampler():
dataset = PaddleNormalDataset(10)
dataloader = DataLoader(dataset, batch_size=32)
sampler = RandomSampler(dataset)

replaced_loader = replace_sampler(dataloader, sampler)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)

+ 24
- 24
tests/core/log/test_logger.py View File

@@ -6,13 +6,16 @@ import logging
import re

from fastNLP.envs.env import FASTNLP_LAUNCH_TIME
from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import synchronize_safe_rm
from fastNLP.core.log.logger import logger

from tests.helpers.utils import magic_argv_env_context, recover_logger


# Tests TorchDDPDriver;
@magic_argv_env_context
def test_add_file_ddp_1():
@recover_logger
def test_add_file_ddp_1_torch():
"""
Tests when path is a file path whose containing directory exists;

@@ -56,11 +59,11 @@ def test_add_file_ddp_1():
synchronize_safe_rm(filepath)
dist.barrier()
dist.destroy_process_group()
logger.removeHandler(handler)


@magic_argv_env_context
def test_add_file_ddp_2():
@recover_logger
def test_add_file_ddp_2_torch():
"""
Tests when path is a file path whose containing directory does not exist;
"""
@@ -103,14 +106,14 @@ def test_add_file_ddp_2():
assert len(pattern.findall(line)) == 1
finally:
synchronize_safe_rm(path)
logger.removeHandler(handler)

dist.barrier()
dist.destroy_process_group()


@magic_argv_env_context
def test_add_file_ddp_3():
@recover_logger
def test_add_file_ddp_3_torch():
"""
path = None;

@@ -155,10 +158,10 @@ def test_add_file_ddp_3():
synchronize_safe_rm(file)
dist.barrier()
dist.destroy_process_group()
logger.removeHandler(handler)

@magic_argv_env_context
def test_add_file_ddp_4():
@recover_logger
def test_add_file_ddp_4_torch():
"""
Tests when path is a directory;
"""
@@ -200,7 +203,6 @@ def test_add_file_ddp_4():
assert len(pattern.findall(line)) == 1
finally:
synchronize_safe_rm(path)
logger.removeHandler(handler)

dist.barrier()
dist.destroy_process_group()
@@ -209,12 +211,11 @@ def test_add_file_ddp_4():
class TestLogger:
msg = 'some test log msg'

@recover_logger
def test_add_file_1(self):
"""
Tests when path is a file path whose containing directory exists;
"""
from fastNLP.core.log.logger import logger

path = Path(tempfile.mkdtemp())
try:
filepath = path.joinpath('log.txt')
@@ -225,14 +226,12 @@ class TestLogger:
assert self.msg in line
finally:
synchronize_safe_rm(path)
logger.removeHandler(handler)

@recover_logger
def test_add_file_2(self):
"""
Tests when path is a file path whose containing directory does not exist;
"""
from fastNLP.core.log.logger import logger

origin_path = Path(tempfile.mkdtemp())

try:
@@ -245,14 +244,12 @@ class TestLogger:
assert self.msg in line
finally:
synchronize_safe_rm(origin_path)
logger.removeHandler(handler)

@recover_logger
def test_add_file_3(self):
"""
Tests when path is None;
"""
from fastNLP.core.log.logger import logger

handler = logger.add_file()
logger.info(self.msg)

@@ -264,14 +261,12 @@ class TestLogger:
line = ''.join([l for l in f])
assert self.msg in line
file.unlink()
logger.removeHandler(handler)

@recover_logger
def test_add_file_4(self):
"""
Tests when path is a directory;
"""
from fastNLP.core.log.logger import logger

path = Path(tempfile.mkdtemp())
try:
handler = logger.add_file(path)
@@ -285,16 +280,21 @@ class TestLogger:
assert self.msg in line
finally:
synchronize_safe_rm(path)
logger.removeHandler(handler)

@recover_logger
def test_stdout(self, capsys):
from fastNLP.core.log.logger import logger

handler = logger.set_stdout(stdout="raw")
logger.info(self.msg)
logger.debug('aabbc')
captured = capsys.readouterr()
assert "some test log msg\n" == captured.out

logger.removeHandler(handler)
@recover_logger
def test_warning_once(self, capsys):
logger.warning_once('#')
logger.warning_once('#')
logger.warning_once('@')
captured = capsys.readouterr()
assert captured.out.count('#') == 1
assert captured.out.count('@') == 1


+ 85
- 8
tests/core/samplers/test_reproducible_batch_sampler.py View File

@@ -3,6 +3,7 @@ from array import array
import numpy as np
import pytest
from itertools import chain
from copy import deepcopy

from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
@@ -30,7 +31,7 @@ class TestReproducibleBatchSampler:
_get_re_batchsampler = dataloader.batch_sampler
assert isinstance(_get_re_batchsampler, RandomBatchSampler)
state = _get_re_batchsampler.state_dict()
assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size,
assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
"sampler_type": "RandomBatchSampler"}

# 2. Resume from the checkpoint and create a new dataloader;
@@ -413,26 +414,102 @@ class TestBucketedBatchSampler:
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
@pytest.mark.parametrize('num_replica', [2, 3])
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replica):
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replica=2):
@pytest.mark.parametrize('num_replicas', [2, 3])
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replicas):
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replicas=2):
dataset = DatasetWithVaryLength(num_of_data=num_samples)
batch_size = 6
if num_replica*batch_size > num_samples:
if num_replicas*batch_size > num_samples:
return
num_batch_per_bucket = 10
samplers = []
lengths = []
for i in range(num_replica):
for i in range(num_replicas):
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last)
sampler.set_distributed(num_replica, rank=i, pad=pad)
sampler.set_distributed(num_replicas, rank=i, pad=pad)
sampler.set_epoch(0)
samplers.append(sampler)
lengths.append(len(list(iter(sampler))))
assert len(set(lengths))==1
bucket_diff = batch_size * num_batch_per_bucket * num_replica
bucket_diff = batch_size * num_batch_per_bucket * num_replicas

for bs in zip(*samplers):
diff = max(chain(*bs)) - min(chain(*bs))
assert diff <= bucket_diff

@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
@pytest.mark.parametrize('num_replicas', [1, 2, 3])
def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas):
"""
Tests whether already-consumed (forwarded) data can be recovered correctly; because the DataLoader prefetches,
the sampler's own num_consumed_samples may overcount

:return:
"""
batch_size = 6
num_batch_per_bucket = 10
dataset = DatasetWithVaryLength(num_of_data=num_samples)
samplers = []
for i in range(num_replicas):
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last)

sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad)
samplers.append(sampler)
count = 0
already_seen_sets = [set()]
already_seen_set = set()
for batchs in zip(*samplers):
batch = chain(*batchs)
already_seen_set.update(batch)
already_seen_sets.append(deepcopy(already_seen_set))
count += 1
if count > 3:
break
states = samplers[0].state_dict()
for i in range(len(already_seen_sets)):
if states['num_consumed_samples_array'] is not None:
states['num_consumed_samples'] = states['num_consumed_samples_array'][i]
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1,
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
drop_last=drop_last)
sampler.set_epoch(0)
already_seen_set = deepcopy(already_seen_sets[i])
for batch in sampler:
already_seen_set.update(batch)
assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len(
dataset)

# Test saving again after a previous save
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1,
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
drop_last=drop_last)
sampler.set_epoch(0)
if states['num_consumed_samples_array'] is not None:
states['num_consumed_samples'] = states['num_consumed_samples_array'][2]
if len(already_seen_sets)<3:
return
already_seen_set = already_seen_sets[2]
count = 0
for batch in sampler:
already_seen_set.update(batch)
count += 1
if count > 6:
break

states = sampler.state_dict()
if states['num_consumed_samples_array'] is not None:
states['num_consumed_samples'] = states['num_consumed_samples_array'][count]
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2,
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
drop_last=drop_last)
sampler.load_state_dict(states)
sampler.set_epoch(0)
for batch in sampler:
already_seen_set.update(batch)

assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset)
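The pattern exercised here deserves a note: because a DataLoader prefetches, the sampler's num_consumed_samples at save time can run ahead of what the training loop actually used, so num_consumed_samples_array records the true count per yielded batch and lets the caller rewind. A sketch of that rewind step, reusing the names from this test:

states = samplers[0].state_dict()
batches_actually_used = 3  # how many batches the training loop really consumed
if states["num_consumed_samples_array"] is not None:
    # rewind the (possibly prefetch-inflated) counter to the true position
    states["num_consumed_samples"] = states["num_consumed_samples_array"][batches_actually_used]
resumed = BucketedBatchSampler(dataset, length=dataset.data, batch_size=6)
resumed.load_state_dict(states)  # the batch size may differ from the saved run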

+ 60
- 2
tests/core/samplers/test_reproducible_sampler.py View File

@@ -3,6 +3,7 @@ import pytest

from functools import partial
from itertools import chain
from copy import deepcopy

from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler
from tests.helpers.datasets.torch_data import TorchNormalDataset
@@ -180,6 +181,63 @@ class TestRandomSamplerYh:
assert seen <= 1 if pad else seen == 0
assert seen_in_other_rank<=1 # padding may introduce duplicates

@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
@pytest.mark.parametrize('num_replicas', [1, 2, 3])
def test_num_consumed_samples_array(self, shuffle, pad, num_samples, num_replicas):
# Tests that state can still be recovered when the sampler over-generates
dataset = DatasetWithVaryLength(num_of_data=num_samples)
samplers = []
for i in range(num_replicas):
sampler = RandomSampler(dataset, shuffle=shuffle)
sampler.set_epoch(0)
sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad)
samplers.append(sampler)
count = 0
already_seen_sets = [set()]
already_seen_set = set()
for idxes in zip(*samplers):
already_seen_set.update(idxes)
already_seen_sets.append(deepcopy(already_seen_set))
count += 1
if count > 3:
break
states = samplers[0].state_dict()
for i in range(len(already_seen_sets)):
if states['num_consumed_samples_array'] is not None:
states['num_consumed_samples'] = states['num_consumed_samples_array'][i]
sampler = RandomSampler(dataset, shuffle=shuffle)
already_seen_set = deepcopy(already_seen_sets[i])
for batch in sampler:
already_seen_set.add(batch)
assert len(already_seen_set) == len(dataset)
# Test saving again after a previous save
sampler = RandomSampler(dataset, shuffle=shuffle)
sampler.set_epoch(0)
if states['num_consumed_samples_array'] is not None:
states['num_consumed_samples'] = states['num_consumed_samples_array'][2]
if len(already_seen_sets)<3:
return
already_seen_set = already_seen_sets[2]
count = 0
for idx in sampler:
already_seen_set.add(idx)
count += 1
if count > 6:
break

states = sampler.state_dict()
if states['num_consumed_samples_array'] is not None:
states['num_consumed_samples'] = states['num_consumed_samples_array'][count]
sampler = RandomSampler(dataset, shuffle=shuffle)
sampler.load_state_dict(states)
sampler.set_epoch(0)
for idx in sampler:
already_seen_set.add(idx)

assert len(already_seen_set)==len(dataset)


class TestRandomSampler:
# Test single device;
@@ -386,7 +444,7 @@ class TestSortedSampler:
assert indexes==list(range(num_of_data-1, -1, -1))

@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_replicas', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
def test_multi(self, pad, num_replica, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data)
@@ -540,7 +598,7 @@ class TestSequentialSampler:
assert indexes==list(range(num_of_data))

@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_replicas', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
def test_multi(self, pad, num_replica, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data)


+ 3
- 3
tests/core/samplers/test_unrepeated_sampler.py View File

@@ -25,7 +25,7 @@ class TestUnrepeatedSampler:
indexes = set(sampler)
assert indexes==set(range(num_of_data))

@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_replicas', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
@pytest.mark.parametrize('shuffle', [False, True])
def test_multi(self, num_replica, num_of_data, shuffle):
@@ -50,7 +50,7 @@ class TestUnrepeatedSortedSampler:
indexes = list(sampler)
assert indexes==list(range(num_of_data-1, -1, -1))

@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_replicas', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
def test_multi(self, num_replica, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data)
@@ -81,7 +81,7 @@ class TestUnrepeatedSequentialSampler:
indexes = list(sampler)
assert indexes==list(range(num_of_data))

@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_replicas', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
def test_multi(self, num_replica, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data)


+ 187
- 0
tests/core/utils/test_utils.py View File

@@ -0,0 +1,187 @@
from functools import partial

import pytest

from fastNLP.core.utils.utils import auto_param_call, _check_valid_parameters_number, _get_fun_msg
from fastNLP.core.metrics import Metric



class TestAutoParamCall:
def test_basic(self):
def fn(x):
return x
x = {'x': 3, 'y': 4}
r = auto_param_call(fn, x)
assert r==3

xs = []
for i in range(10):
xs.append({f'x{i}': i})
def fn(x0, x1, x2, x3):
return x0 + x1 + x2 + x3
r = auto_param_call(fn, *xs)
assert r == 0 + 1+ 2+ 3

def fn(chongfu1, chongfu2, buChongFu):
pass
with pytest.raises(BaseException) as exc_info:
auto_param_call(fn, {'chongfu1': 3, "chongfu2":4, 'buChongFu':2},
{'chongfu1': 1, 'chongfu2':2, 'buChongFu':2})
assert 'The following key present in several inputs' in exc_info.value.args[0]
assert 'chongfu1' in exc_info.value.args[0] and 'chongfu2' in exc_info.value.args[0]

# Unused duplicate keys do not raise
def fn(chongfu1, buChongFu):
pass
auto_param_call(fn, {'chongfu1': 1, "chongfu2":4, 'buChongFu':2},
{'chongfu1': 1, 'chongfu2':2, 'buChongFu':2})

# A custom signature_fn can be provided
def fn1(**kwargs):
kwargs.pop('x')
kwargs.pop('y')
assert len(kwargs)==0
def fn(x, y):
pass
x = {'x': 3, 'y': 4}
r = auto_param_call(fn1, x, signature_fn=fn)

# Raises when a required parameter is not provided
def fn(meiti1, meiti2, tigong):
pass
with pytest.raises(BaseException) as exc_info:
auto_param_call(fn, {'tigong':1})
assert 'meiti1' in exc_info.value.args[0] and 'meiti2' in exc_info.value.args[0]

# Default values are used as fallbacks
def fn(x, y=100):
return x + y
r = auto_param_call(fn, {'x': 10, 'y': 20})
assert r==30
assert auto_param_call(fn, {'x': 10, 'z': 20})==110

# Test the use of mapping
def fn(x, y=100):
return x + y
r = auto_param_call(fn, {'x1': 10, 'y1': 20}, mapping={'x1': 'x', 'y1': 'y', 'meiyong': 'meiyong'})
assert r==30

# Test a function that needs no parameters
def fn():
return 1
assert 1 == auto_param_call(fn, {'x':1})

# Test that calling a class method works
assert 2==auto_param_call(self.call_this, {'x':1 ,'y':1})
assert 2==auto_param_call(self.call_this, {'x':1,'y':1, 'z':1},mapping={'z': 'self'})

def test_msg(self):
with pytest.raises(BaseException) as exc_info:
auto_param_call(self.call_this, {'x':1})
assert 'TestAutoParamCall.call_this' in exc_info.value.args[0]

with pytest.raises(BaseException) as exc_info:
auto_param_call(call_this_for_auto_param_call, {'x':1})
assert __file__ in exc_info.value.args[0]
assert 'call_this_for_auto_param_call' in exc_info.value.args[0]

with pytest.raises(BaseException) as exc_info:
auto_param_call(self.call_this_two, {'x':1})
assert __file__ in exc_info.value.args[0]

with pytest.raises(BaseException) as exc_info:
auto_param_call(call_this_for_auto_param_call, {'x':1}, signature_fn=self.call_this)
assert 'TestAutoParamCall.call_this' in exc_info.value.args[0] # should reflect the signature_fn's info

def call_this(self, x, y):
return x + y

def call_this_two(self, x, y, z=pytest, **kwargs):
return x + y

def test_metric_auto_param_call(self):
metric = AutoParamCallMetric()
with pytest.raises(BaseException):
auto_param_call(metric.update, {'y':1}, signature_fn=metric.update.__wrapped__)


class AutoParamCallMetric(Metric):
def update(self, x):
pass


def call_this_for_auto_param_call(x, y):
return x + y


class TestCheckNumberOfParameters:
def test_validate_every(self):
def validate_every(trainer):
pass
_check_valid_parameters_number(validate_every, expected_params=['trainer'])

# Without a default value, an extra parameter raises
def validate_every(trainer, other):
pass
with pytest.raises(RuntimeError) as exc_info:
_check_valid_parameters_number(validate_every, expected_params=['trainer'])
assert "2 parameters" in exc_info.value.args[0]
print(exc_info.value.args[0])

# With a default value it is fine
def validate_every(trainer, other=1):
pass
_check_valid_parameters_number(validate_every, expected_params=['trainer'])

# Too many expected parameters
def validate_every(trainer):
pass
with pytest.raises(RuntimeError) as exc_info:
_check_valid_parameters_number(validate_every, expected_params=['trainer', 'other'])
assert "accepts 1 parameters" in exc_info.value.args[0]
print(exc_info.value.args[0])

# Using partial
def validate_every(trainer, other):
pass
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer'])
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other'])
with pytest.raises(RuntimeError) as exc_info:
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other', 'more'])
assert 'accepts 2 parameters' in exc_info.value.args[0]
print(exc_info.value.args[0])

# With *args or **kwargs present, extra parameters do not raise
def validate_every(trainer, *args):
pass
_check_valid_parameters_number(validate_every, expected_params=['trainer', 'other', 'more'])

def validate_every(trainer, **kwargs):
pass
_check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more'])

# For class methods, self is stripped
class InnerClass:
def demo(self, x):
pass

def no_param(self):
pass

def param_kwargs(self, **kwargs):
pass

inner = InnerClass()
with pytest.raises(RuntimeError) as exc_info:
_check_valid_parameters_number(inner.demo, expected_params=['trainer', 'other', 'more'])
assert 'accepts 1 parameters' in exc_info.value.args[0]

_check_valid_parameters_number(inner.demo, expected_params=['trainer'])


def test_get_fun_msg():
def demo(x):
pass

print(_get_fun_msg(_get_fun_msg))

+ 8
- 2
tests/helpers/callbacks/helper_callbacks.py View File

@@ -101,12 +101,18 @@ class RecordTrainerEventTriggerCallback(Callback):
def on_after_backward(self, trainer):
print("on_after_backward")

def on_before_optimizer_step(self, trainer, optimizers):
print("on_before_optimizer_step")
def on_before_optimizers_step(self, trainer, optimizers):
print("on_before_optimizers_step")

def on_after_optimizers_step(self, trainer, optimizers):
print("on_after_optimizers_step")

def on_before_zero_grad(self, trainer, optimizers):
print("on_before_zero_grad")

def on_after_zero_grad(self, trainer, optimizers):
print("on_after_zero_grad")

def on_validate_begin(self, trainer):
print("on_validate_begin")



+ 27
- 0
tests/helpers/models/torch_model.py View File

@@ -37,6 +37,7 @@ class TorchNormalModel_Classification_1(nn.Module):
x = torch.max(x, dim=-1)[1]
return {"preds": x, "target": y}


class TorchNormalModel_Classification_2(nn.Module):
"""
Implements only a forward function, to test the scenario where the user initializes DDP externally;
@@ -61,5 +62,31 @@ class TorchNormalModel_Classification_2(nn.Module):
return {"loss": loss, "preds": x, "target": y}


class TorchNormalModel_Classification_3(nn.Module):
"""
Implements only a forward function, to test the scenario where the user initializes DDP externally;
auto_param_call is disabled; forward takes a single batch argument;
"""
def __init__(self, num_labels, feature_dimension):
super(TorchNormalModel_Classification_3, self).__init__()
self.num_labels = num_labels

self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)
self.ac1 = nn.ReLU()
self.linear2 = nn.Linear(in_features=10, out_features=10)
self.ac2 = nn.ReLU()
self.output = nn.Linear(in_features=10, out_features=num_labels)
self.loss_fn = nn.CrossEntropyLoss()

def forward(self, batch):
x = batch["x"]
y = batch["y"]
x = self.ac1(self.linear1(x))
x = self.ac2(self.linear2(x))
x = self.output(x)
loss = self.loss_fn(x, y)
x = torch.max(x, dim=-1)[1]
return {"loss": loss, "preds": x, "target": y}




+ 14
- 17
tests/helpers/utils.py View File

@@ -2,34 +2,31 @@ import os
import sys
import __main__
from functools import wraps
import inspect
from inspect import ismethod
import functools
from copy import deepcopy
from io import StringIO
import time

import numpy as np

from fastNLP.core.utils.utils import get_class_that_defined_method
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK
from fastNLP.core.drivers.utils import distributed_open_proc
from fastNLP.core.log import logger


def get_class_that_defined_method(meth):
if isinstance(meth, functools.partial):
return get_class_that_defined_method(meth.func)
if inspect.ismethod(meth) or (inspect.isbuiltin(meth) and getattr(meth, '__self__', None) is not None and getattr(meth.__self__, '__class__', None)):
for cls in inspect.getmro(meth.__self__.__class__):
if meth.__name__ in cls.__dict__:
return cls
meth = getattr(meth, '__func__', meth) # fallback to __qualname__ parsing
if inspect.isfunction(meth):
cls = getattr(inspect.getmodule(meth),
meth.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0],
None)
if isinstance(cls, type):
return cls
return getattr(meth, '__objclass__', None) # handle special descriptor objects
def recover_logger(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
# Save the logger's state
handlers = [handler for handler in logger.handlers]
level = logger.level
res = fn(*args, **kwargs)
logger.handlers = handlers
logger.setLevel(level)
return res

return wrapper
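recover_logger replaces the per-test logger.removeHandler(handler) bookkeeping removed above: it snapshots logger.handlers and the level before the test and restores them afterwards. A usage sketch (hypothetical test body):

@recover_logger
def test_add_file_example():
    handler = logger.add_file('log.txt')
    logger.info('some msg')
    # no explicit logger.removeHandler(handler) needed: the decorator
    # restores the saved handler list and level after fn returns

Note the wrapper restores state only on a normal return; wrapping fn in try/finally would also cover tests that raise.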


def magic_argv_env_context(fn):

