
fix conflict

tags/v1.0.0alpha
x54-729 3 years ago
commit 98644d2d0b
37 changed files with 378 additions and 501 deletions
  1. +20 -1 fastNLP/core/callbacks/callback.py
  2. +50 -20 fastNLP/core/callbacks/checkpoint_callback.py
  3. +9 -5 fastNLP/core/callbacks/early_stop_callback.py
  4. +3 -2 fastNLP/core/callbacks/load_best_model_callback.py
  5. +8 -6 fastNLP/core/callbacks/progress_callback.py
  6. +18 -34 fastNLP/core/controllers/evaluator.py
  7. +2 -4 fastNLP/core/controllers/loops/train_batch_loop.py
  8. +61 -40 fastNLP/core/controllers/trainer.py
  9. +3 -3 fastNLP/core/controllers/utils/state.py
  10. +1 -1 fastNLP/core/controllers/utils/utils.py
  11. +29 -100 fastNLP/core/drivers/driver.py
  12. +7 -7 fastNLP/core/drivers/jittor_driver/jittor_driver.py
  13. +5 -5 fastNLP/core/drivers/jittor_driver/single_device.py
  14. +3 -3 fastNLP/core/drivers/paddle_driver/fleet.py
  15. +8 -8 fastNLP/core/drivers/paddle_driver/paddle_driver.py
  16. +8 -8 fastNLP/core/drivers/paddle_driver/single_device.py
  17. +7 -7 fastNLP/core/drivers/paddle_driver/utils.py
  18. +40 -61 fastNLP/core/drivers/torch_driver/ddp.py
  19. +26 -70 fastNLP/core/drivers/torch_driver/single_device.py
  20. +2 -18 fastNLP/core/drivers/torch_driver/torch_driver.py
  21. +6 -51 fastNLP/core/drivers/torch_driver/utils.py
  22. +5 -5 fastNLP/core/drivers/torch_paddle_driver/torch_paddle_driver.py
  23. +3 -3 fastNLP/core/log/logger.py
  24. +8 -8 tests/core/callbacks/test_checkpoint_callback_torch.py
  25. +1 -1 tests/core/callbacks/test_load_best_model_callback_torch.py
  26. +1 -1 tests/core/controllers/_test_distributed_launch_torch_1.py
  27. +1 -1 tests/core/controllers/_test_distributed_launch_torch_2.py
  28. +1 -1 tests/core/controllers/test_trainer_event_trigger.py
  29. +2 -2 tests/core/controllers/test_trainer_fleet.py
  30. +2 -2 tests/core/controllers/test_trainer_fleet_outside.py
  31. +1 -1 tests/core/controllers/test_trainer_paddle.py
  32. +8 -8 tests/core/controllers/test_trainer_w_evaluator_torch.py
  33. +5 -5 tests/core/controllers/test_trainer_wo_evaluator_torch.py
  34. +7 -7 tests/core/drivers/paddle_driver/test_single_device.py
  35. +1 -0 tests/core/utils/test_utils.py
  36. +15 -1 tests/helpers/datasets/normal_data.py
  37. +1 -1 tests/helpers/models/torch_model.py

+20 -1 fastNLP/core/callbacks/callback.py

@@ -390,4 +390,23 @@ class HasMonitorCallback(Callback):
if (self.larger_better and monitor_value1 > monitor_value2) or \
(not self.larger_better and monitor_value1 < monitor_value2):
better = True
return better
return better

@property
def monitor_name(self):
"""
Return the name of the monitor; if the monitor is a callable, return the name of that function.

:return:
"""
if callable(self.monitor):
try:
monitor_name = self.monitor.__qualname__
except:
monitor_name = self.monitor.__name__
elif self.monitor is None:
return None
else:
# This must be monitor rather than real_monitor, because real_monitor is reset to monitor when the user runs again
monitor_name = str(self.monitor)
return monitor_name
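The new `monitor_name` property above makes callable monitors usable inside names such as `callback_name`. A minimal sketch of the intended behaviour, assuming `EarlyStopCallback` (whose diff appears further below) is importable from `fastNLP.core.callbacks`; the `pick_f1` helper and the result key are purely illustrative:

    from fastNLP.core.callbacks import EarlyStopCallback

    def pick_f1(results):
        # receives the evaluation results (a dict) and returns a float
        return results.get("f1", 0.0)

    cb = EarlyStopCallback(monitor=pick_f1, larger_better=True, patience=5)
    # For a callable monitor, monitor_name falls back to __qualname__/__name__,
    # so cb.callback_name becomes "EarlyStopCallback#monitor-pick_f1#patience-5".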

+50 -20 fastNLP/core/callbacks/checkpoint_callback.py

@@ -19,11 +19,11 @@ from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir
class CheckpointCallback(HasMonitorCallback):
def __init__(
self,
monitor,
monitor:Optional[Union[str, Callable]]=None,
save_folder: Optional[Union[str, Path]] = None,
save_every_n_epochs: Optional[int] = None,
save_every_n_batches: Optional[int] = None,
save_last: bool = True,
save_last: bool = False,
save_topk: Optional[int] = None,
save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None,
larger_better: bool = True,
@@ -31,12 +31,32 @@ class CheckpointCallback(HasMonitorCallback):
model_save_fn: Optional[Callable] = None,
**kwargs,
):
"""
Please use ModelCheckpointCallback or TrainerCheckpointCallback instead.

:param monitor: the metric to monitor. If no exact match is found in the evaluation results, the longest-common-substring
    algorithm is used to pick the best-matching name as the monitor. If None, the monitor set on the Trainer is used. A
    callable may also be passed; it receives the evaluation results (a dict) and returns a float as the monitor value.
:param save_folder: the folder to save into. fastNLP creates a timestamped sub-folder inside it, so different runs are
    saved into different timestamp folders. If None, the current working directory is used.
:param save_every_n_epochs: save once every this many epochs.
:param save_every_n_batches: save once every this many batches.
:param save_last: if True, save at the end of every epoch, overwriting the previous save.
:param save_topk: keep the top-k checkpoints ranked by the monitor value.
:param save_on_exception: whether to save when an exception occurs; pass the exception class(es) to catch.
:param larger_better: whether a larger monitor value is better.
:param only_state_dict: whether to save only the model's state_dict; ignored when model_save_fn is not None.
:param model_save_fn: a custom save function called whenever a save is triggered. It should accept a folder as its only
    argument and return nothing. If model_save_fn is given, fastNLP does no model saving of its own. In the multi-card
    setting this function is only run on rank 0.
:param kwargs:
"""
super().__init__(monitor=monitor, larger_better=larger_better,
must_have_monitor=save_topk is not None)
if save_folder is None:
logger.warning(
"Parameter `path` is None, and we will use the current work directory to find and load your model.")
save_folder = Path.cwd()
save_folder = Path(save_folder)
if not save_folder.exists():
raise NotADirectoryError(f"Path '{save_folder.absolute()}' is not existed!")
elif save_folder.is_file():
@@ -71,7 +91,7 @@ class CheckpointCallback(HasMonitorCallback):
else:
save_on_exception = []

self.save_folder = Path(save_folder)
self.save_folder = save_folder
self.save_every_n_epochs = save_every_n_epochs
self.save_every_n_batches = save_every_n_batches
self.save_last = save_last
@@ -88,18 +108,15 @@ class CheckpointCallback(HasMonitorCallback):
# Note that only process 0 should perform this operation, because when the user launches processes with
# python -m torch.distributed.launch, FASTNLP_LAUNCH_TIME has a different value in every process;
self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])
# We only need to make sure the folder creation happens on process 0; the later actual saves are never run by the other processes anyway;
synchronize_mkdir(self.timestamp_path)
# This folder is only created when a save is actually about to happen.

def on_after_trainer_initialized(self, trainer, driver):
if self.save_topk is not None:
super().on_after_trainer_initialized(trainer, driver)
if self.save_topk is not None and trainer.evaluator is None:
logger.warning("You set `save_topk`, but `validate_dataloaders` is not set in Trainer.")
logger.warning("You set `save_topk`, but `evaluate_dataloaders` is not set in Trainer.")

def on_validate_end(self, trainer, results):
if len(results) == 0:
return
self._save_topk(trainer, results)

def on_train_epoch_end(self, trainer: "fastNLP.Trainer"):
@@ -136,16 +153,17 @@ class CheckpointCallback(HasMonitorCallback):
states['timestamp_path'] = str(self.timestamp_path.absolute())
states['_topk_model'] = deepcopy(self._topk_model)
states['save_topk'] = 0 if self.save_topk is None else self.save_topk
states['_real_monitor'] = self._real_monitor
if isinstance(self._real_monitor, str):
states['_real_monitor'] = self._real_monitor
return states

def on_load_checkpoint(self, trainer, states: Optional[Dict]):
timestamp_path = states['timestamp_path']
if not os.path.exists(timestamp_path):
logger.info(f"The resuming save folder {timestamp_path} is not exists, will checkpoint save to "
logger.info(f"The resuming checkpoint folder {timestamp_path} is not exists, will checkpoint save to "
f" {self.timestamp_path.absolute()}.")
else:
logger.info(f"Resume to save in path: {timestamp_path}.")
logger.info(f"Resume to checkpoint in path: {timestamp_path}.")
self.timestamp_path = Path(timestamp_path)
_topk_model = states['_topk_model']
save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk'])
@@ -153,7 +171,8 @@ class CheckpointCallback(HasMonitorCallback):
assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \
f"as {save_topk}."
self._topk_model.update(self._topk_model)
self._real_monitor = states["real_monitor"]

self._real_monitor = states["_real_monitor"]

def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict):
"""
@@ -231,9 +250,9 @@ class ModelCheckpointCallback(CheckpointCallback):
if model_save_fn is None, a fastnlp_model.pkl.tar file is created in each of the folders above.
If model_save_fn is not None, fastNLP passes the absolute path of the folder to that function and creates no files in the folder itself.

:param monitor: the name of the metric to monitor. If no exact match is found in the evaluation results, the
    longest-common-substring algorithm is used to pick the best-matching name as the monitor. If None, the value is taken
    from the Trainer. A callable may also be passed; it receives the evaluation results (a dict) and returns a float.
:param monitor: the metric to monitor. If no exact match is found in the evaluation results, the longest-common-substring
    algorithm is used to pick the best-matching name as the monitor. If None, the monitor set on the Trainer is used. A
    callable may also be passed; it receives the evaluation results (a dict) and returns a float as the monitor value.
:param save_folder: the folder to save into. fastNLP creates a timestamped sub-folder inside it, so different runs are
    saved into different timestamp folders. If None, the current working directory is used.
:param save_every_n_epochs: save once every this many epochs.
@@ -249,6 +268,11 @@ class ModelCheckpointCallback(CheckpointCallback):
"""
@property
def save_fn_name(self):
"""
Which function on the Trainer to call.

:return:
"""
return 'save_model'

@property
@@ -257,7 +281,7 @@ class ModelCheckpointCallback(CheckpointCallback):
This value determines whether two CheckpointCallback instances can share the state used for resuming training from a checkpoint;
:return:
"""
return f"model_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"
return f"model_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"

@property
def folder_prefix(self):
@@ -279,9 +303,9 @@ class TrainerCheckpointCallback(CheckpointCallback):
if model_save_fn is None, two files are created in each of the folders above: fastnlp_trainer.pkl.tar and fastnlp_model.pkl.tar.
If model_save_fn is not None, fastNLP only creates fastnlp_trainer.pkl.tar in each folder.

:param monitor: the name of the metric to monitor. If no exact match is found in the evaluation results, the
    longest-common-substring algorithm is used to pick the best-matching name as the monitor. If None, the value is taken
    from the Trainer. A callable may also be passed; it receives the evaluation results (a dict) and returns a float.
:param monitor: the metric to monitor. If no exact match is found in the evaluation results, the longest-common-substring
    algorithm is used to pick the best-matching name as the monitor. If None, the monitor set on the Trainer is used. A
    callable may also be passed; it receives the evaluation results (a dict) and returns a float as the monitor value.
:param save_folder: the folder to save into. fastNLP creates a timestamped sub-folder inside it, so different runs are
    saved into different timestamp folders. If None, the current working directory is used.
:param save_every_n_epochs: save once every this many epochs.
@@ -297,6 +321,11 @@ class TrainerCheckpointCallback(CheckpointCallback):
"""
@property
def save_fn_name(self):
"""
Which function on the Trainer to call.

:return:
"""
return 'save'

@property
@@ -305,7 +334,8 @@ class TrainerCheckpointCallback(CheckpointCallback):
This value determines whether two CheckpointCallback instances can share the state used for resuming training from a checkpoint;
:return:
"""
return f"trainer_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"

return f"trainer_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"

@property
def folder_prefix(self):
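A hedged usage sketch of the updated `ModelCheckpointCallback` signature (monitor is now optional and `save_last` defaults to False); the import path follows the file location above, and the folder and values are placeholders:

    from pathlib import Path
    from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback

    ckpt = ModelCheckpointCallback(
        monitor=None,                        # fall back to the monitor set on the Trainer
        save_folder=Path("./checkpoints"),   # must already exist, or NotADirectoryError is raised
        save_every_n_epochs=1,
        save_topk=3,                         # only meaningful when evaluate_dataloaders is set
        larger_better=True,
        only_state_dict=True,
    )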


+9 -5 fastNLP/core/callbacks/early_stop_callback.py

@@ -12,8 +12,9 @@ class EarlyStopCallback(HasMonitorCallback):
def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10):
"""

:param str monitor: the metric to monitor. If None, the monitor set on the Trainer is used. A callable may also be
    passed; it receives the evaluation results (a dict) and returns a float as the monitor value.
:param str monitor: the metric to monitor. If no exact match is found in the evaluation results, the
    longest-common-substring algorithm is used to pick the best-matching name as the monitor. If None, the monitor set on
    the Trainer is used. A callable may also be passed; it receives the evaluation results (a dict) and returns a float.
:param larger_better: whether a larger monitor value is better.
:param patience: stop after this many validations without improvement.
"""
@@ -46,17 +47,20 @@ class EarlyStopCallback(HasMonitorCallback):
states = {
'patience': self.patience,
'wait': self.wait,
'monitor': self.monitor,
'monitor_value': self.monitor_value
}
if not callable(self._real_monitor):
states['_real_monitor'] = self._real_monitor
return states

def on_load_checkpoint(self, trainer, states):
self.patience = states['patience']
self.wait = states['wait']
self.monitor = states['monitor']
self.monitor_value = float(states['monitor_value'])
if '_real_monitor' in states:
self._real_monitor = states['_real_monitor']

@property
def callback_name(self):
return f'EarlyStopCallback#monitor-{self.monitor}#patience-{self.patience}'
return f'EarlyStopCallback#monitor-{self.monitor_name}#patience-{self.patience}'


+3 -2 fastNLP/core/callbacks/load_best_model_callback.py

@@ -21,8 +21,9 @@ class LoadBestModelCallback(HasMonitorCallback):
"""
Saves the model with the best monitor value and reloads it when training ends. The best model can only be loaded when training finishes normally.

:param str monitor: the metric to monitor. If None, the monitor set on the Trainer is used. A callable may also be
    passed; it receives the evaluation results (a dict) and returns a float as the monitor value.
:param str monitor: the metric to monitor. If no exact match is found in the evaluation results, the
    longest-common-substring algorithm is used to pick the best-matching name as the monitor. If None, the monitor set on
    the Trainer is used. A callable may also be passed; it receives the evaluation results (a dict) and returns a float.
:param larger_better: whether a larger metric value is better.
:param save_folder: the folder to save into. If empty, the weights are kept in memory; otherwise a copy of the weights is
    written to file. For multi-machine training with a non-empty value, make sure the path is accessible from every
    machine. This value must not be empty when model_save_fn is not None.


+8 -6 fastNLP/core/callbacks/progress_callback.py

@@ -44,10 +44,11 @@ class RichCallback(ProgressCallback):

:param print_every: update the display every this many batches.
:param loss_round_ndigit: how many digits of the displayed loss to keep.
:param monitor: when a better result for this key is detected, it is printed in a different colour as a hint. If None,
    the monitor set on the Trainer is used. A callable may also be passed; it receives the evaluation results (a dict)
    and returns a float as the monitor value.
:param larger_better: whether a larger monitor result is better.
:param format_json: whether to format the json before printing.
:param monitor: when a better result for this key is detected, it is printed in a different colour as a hint. If no exact
    match is found in the evaluation results, the longest-common-substring algorithm is used to pick the best-matching
    name as the monitor. If None, the monitor set on the Trainer is used. A callable may also be passed; it receives the
    evaluation results (a dict) and returns a float as the monitor value.
:param larger_better: whether a larger monitor result is better.
:param format_json: whether to format the json before printing.
"""
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False)
self.print_every = print_every
@@ -136,8 +137,9 @@ class RawTextCallback(ProgressCallback):

:param print_every: update the display every this many batches.
:param loss_round_ndigit: how many digits of the displayed loss to keep.
:param monitor: when a better result for this key is detected, it is printed in a different colour as a hint. A callable
    may also be passed; it receives the evaluation results (a dict) and returns a float as the monitor value.
:param monitor: when a better result for this key is detected, it is printed in a different colour as a hint. If no exact
    match is found in the evaluation results, the longest-common-substring algorithm is used to pick the best-matching
    name as the monitor. If None, the monitor set on the Trainer is used. A callable may also be passed; it receives the
    evaluation results (a dict) and returns a float as the monitor value.
:param larger_better: whether a larger monitor result is better.
:param format_json: whether to format the json before printing.
"""


+18 -34 fastNLP/core/controllers/evaluator.py

@@ -36,10 +36,10 @@ class Evaluator:
model,
dataloaders,
metrics: Optional[Union[Dict, Metric]] = None,
driver: Union[str, Driver] = 'single',
driver: Union[str, Driver] = 'torch',
device: Optional[Union[int, List[int], str]] = None,
batch_step_fn: Optional[callable] = None,
mode: Optional[Union[str, callable]] = 'validate', # first try evaluate_step; if not found, forward; may be a callable
evaluate_fn: Optional[str] = None, # first try evaluate_step; if not found, forward; may be a callable
input_mapping: Optional[Union[Callable, Dict]] = None,
output_mapping: Optional[Union[Callable, Dict]] = None,
model_wo_auto_param_call: bool = False,
@@ -49,8 +49,8 @@ class Evaluator:
):
"""

:param dataloaders:
:param model:
:param dataloaders:
:param metrics: the metrics to use. Must be a dict whose keys are metric names and whose values are Metric objects.
    fastNLP metrics, torchmetrics, allennlp metrics, etc. are supported.
:param driver: the driver to use.
@@ -58,14 +58,13 @@ class Evaluator:
:param batch_step_fn: a callable taking (evaluator, batch) as arguments, where evaluator is the Evaluator object and
    batch is the object returned by the DataLoader. For an example of a batch_step_fn see the batch_step_fn function in
    fastNLP.core.controller.loops.evaluate_batch_loop.
:param mode: one of ["validate", "test"]. With "validate", first check whether the model has a validate_step function,
    then a test_step function, and fall back to the model's forward computation if neither exists. With "test", first
    look for test_step, then "validate_step", and fall back to the forward computation.
:param evaluate_fn: controls which function the `Evaluator` calls for the forward pass during evaluation, e.g. `model.evaluate_step` versus `model.forward`;
    defaults to None, in which case `evaluate_step` is used as the forward function if the model defines it, and `model.forward` otherwise;
:param input_mapping: the content produced by the dataloader is processed through input_mapping before being fed to the model and the metrics
:param output_mapping: the model output is processed through output_mapping before being fed to the metrics.
:param model_wo_auto_param_call: whether to disable calling our auto_param_call to automatically match batch contents to the parameters of the forward function;
    if this value is True and the batch is a dict, the objects required by forward are extracted from the batch and passed to forward; if this value
    is False, the batch is passed through to forward unchanged. Note the same logic also applies to `train_step`, `validate_step` and `test_step`;
    is False, the batch is passed through to forward unchanged. Note the same logic also applies to `train_step`, `evaluate_step` and `test_step`;
:param fp16: whether to use fp16.
:param verbose: whether to print the evaluation results.
:param kwargs:
@@ -87,9 +86,11 @@ class Evaluator:

self.model = model
self.metrics = metrics

self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, **kwargs)

if dataloaders is None:
raise ValueError("Parameter `dataloaders` can not be None.")
self.dataloaders = dataloaders
self.device = device
self.verbose = verbose

@@ -97,21 +98,12 @@ class Evaluator:
_check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn')
self.batch_step_fn = batch_step_fn

self.mode = mode
assert mode in {'validate', 'test'}, "Parameter `mode` should only be 'validate' or 'test'."

self.input_mapping = input_mapping
self.output_mapping = output_mapping

if not isinstance(dataloaders, dict):
dataloaders = {None: dataloaders}
if mode == "validate":
self._evaluate_step = self.driver.validate_step
self.driver.set_dataloader(validate_dataloaders=dataloaders)
else:
self._evaluate_step = self.driver.test_step
self.driver.set_dataloader(test_dataloaders=dataloaders)
self.mode = mode

self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=batch_step_fn)
self.separator = kwargs.get('separator', '#')
self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True)
@@ -123,10 +115,15 @@ class Evaluator:
self._metric_wrapper = None
_ = self.metrics_wrapper # trigger the check

assert self.driver.has_validate_dataloaders() or self.driver.has_test_dataloaders()
self.driver.setup()
self.driver.barrier()

if evaluate_fn is not None and not isinstance(evaluate_fn, str):
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")
self._evaluate_step, self._evaluate_step_signature_fn = \
self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn)
self.evaluate_fn = evaluate_fn

self.dataloaders = {}
for name, dl in dataloaders.items(): # replace with the correct sampler
dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist=self._dist_sampler, reproducible=False)
@@ -136,7 +133,6 @@ class Evaluator:
if self.progress_bar == 'auto':
self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw'

self.driver.check_evaluator_mode(self.mode)
self.driver.barrier()

def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict:
@@ -156,11 +152,6 @@ class Evaluator:
assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type."
assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0."

if self.mode == 'validate':
assert self.driver.has_validate_dataloaders()
else:
assert self.driver.has_test_dataloaders()

metric_results = {}
self.reset()
evaluate_context = self.driver.get_evaluate_context()
@@ -235,13 +226,6 @@ class Evaluator:
f_rich_progress.destroy_task(self._rich_task_id)
delattr(self, '_rich_task_id')

@property
def eval_dataloaders(self):
if self.mode == "validate":
return self.driver.validate_dataloaders
else:
return self.driver.test_dataloaders

@property
def evaluate_batch_loop(self):
return self._evaluate_batch_loop
@@ -296,13 +280,13 @@ class Evaluator:

def evaluate_step(self, batch):
"""
Passes the batch to the model for processing, choosing between evaluate and test according to the current mode. The result is processed through output_mapping before being
Passes the batch to the model for processing, choosing between evaluate and test according to the current evaluate_fn. The result is processed through output_mapping before being
returned.

:param batch:
:return:
"""
outputs = self._evaluate_step(batch)
outputs = self.driver.model_call(batch, self._evaluate_step, self._evaluate_step_signature_fn)
outputs = match_and_substitute_params(self.output_mapping, outputs)
return outputs
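A sketch of the reworked `Evaluator` call after the `mode` parameter was replaced by `evaluate_fn`; `model`, `val_dataloader` and `acc_metric` are placeholders and the import path is assumed:

    from fastNLP import Evaluator  # assumed import path

    evaluator = Evaluator(
        model=model,
        dataloaders=val_dataloader,     # a single dataloader or a dict of dataloaders
        metrics={"acc": acc_metric},
        driver="torch",                 # the new default driver name
        device=0,
        evaluate_fn=None,               # None -> try model.evaluate_step, else model.forward
    )
    results = evaluator.run(num_eval_batch_per_dl=-1)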



+2 -4 fastNLP/core/controllers/loops/train_batch_loop.py

@@ -20,7 +20,7 @@ class TrainBatchLoop(Loop):
else lambda *args, **kwargs: None
dataloader = iter(dataloader)
indices = None
while True:
while trainer.batch_idx_in_epoch<=trainer.num_batches_per_epoch:
try:
trainer.on_fetch_data_begin()
batch = next(dataloader)
@@ -30,10 +30,8 @@ class TrainBatchLoop(Loop):
batch = trainer.move_data_to_device(batch)
except StopIteration:
break
except EarlyStopException: # the early-stop exception is handled in the Trainer
break
except BaseException as e:
if indices:
if indices and not isinstance(e, EarlyStopException):
logger.debug(f"The following exception happens when running on samples: {indices}")
raise e



+61 -40 fastNLP/core/controllers/trainer.py

@@ -41,19 +41,20 @@ class Trainer(TrainerEventTrigger):
optimizers,
device: Optional[Union[int, List[int], str]] = "cpu",
n_epochs: int = 20,
validate_dataloaders=None,
evaluate_dataloaders=None,
batch_step_fn: Optional[Callable] = None,
validate_batch_step_fn: Optional[Callable] = None,
validate_mode: Union[str, callable] = 'validate',
evaluate_batch_step_fn: Optional[Callable] = None,
train_fn: Optional[str] = None,
evaluate_fn: Optional[str] = None,
callbacks: Union[List[Callback], Callback, None] = None,
metrics: Optional[dict] = None,
validate_every: Optional[Union[int, callable]] = -1,
evaluate_every: Optional[Union[int, Callable]] = -1,
input_mapping: Optional[Union[Callable, Dict]] = None,
output_mapping: Optional[Union[Callable, Dict]] = None,
model_wo_auto_param_call: bool = False,
accumulation_steps: int = 1,
fp16: bool = False,
monitor: Union[str, callable] = None,
monitor: Union[str, Callable] = None,
larger_better: bool = True,
marker: Optional[str] = None,
**kwargs
@@ -79,19 +80,19 @@ class Trainer(TrainerEventTrigger):
4. list(int): if more than one device is used, it should be set this way; when `device` is a list we default to `TorchDDPDriver`;
5. None: if None, the model is left untouched;
:param n_epochs: total number of training epochs, default 20;
:param validate_dataloaders: the validation data, either a single dataset or several; when there are several, note it must be a Dict; default
:param evaluate_dataloaders: the validation data, either a single dataset or several; when there are several, note it must be a Dict; default
    is None;
:param batch_step_fn: replaces the `batch_step_fn` function of `TrainBatchLoop`; note its two parameters must be `trainer` and
    `batch`; default is None;
:param validate_batch_step_fn: replaces the `batch_step_fn` function of the `EvaluateBatchLoop` inside the 'Evaluator'; note its
:param evaluate_batch_step_fn: replaces the `batch_step_fn` function of the `EvaluateBatchLoop` inside the 'Evaluator'; note its
    two parameters must be `evaluator` and `batch`; default is None;
:param validate_mode: controls the mode of the `Evaluator` built into the `Trainer`; its value should be one of ["validate", "test"],
    default "validate". With "validate", first check whether the model has a validate_step function, then a
    test_step function, and fall back to the model's forward computation if neither exists. With "test", first look for a test_step function,
    then "validate_step", and fall back to the forward computation.
:param train_fn: controls which function the `Trainer` calls for the forward pass during training, e.g. `model.train_step` versus `model.forward`;
    defaults to None, in which case `train_step` is used as the forward function if the model defines it, and `model.forward` otherwise;
:param evaluate_fn: controls the behaviour of the `Evaluator` built into the `Trainer`; should be None or a string, used in the same way as train_fn;
    note that this parameter is passed straight to the Evaluator built into the Trainer (if not None);
:param callbacks: the callbacks triggered during training; should be a list whose elements all inherit from the `Callback` class;
:param metrics: should be a dict whose keys are monitor names, e.g. {"acc1": AccMetric(), "acc2": AccMetric()};
:param validate_every: may be a negative integer, a positive integer or a function; negative means validate every that many epochs, positive means validate every that many batches;
:param evaluate_every: may be a negative integer, a positive integer or a function; negative means validate every that many epochs, positive means validate every that many batches;
    when it is a function, the user controls how often the Trainer validates: the function should accept the current trainer object and
    return a bool, True meaning a validation is needed; it is called after every batch to decide whether to validate.
:param input_mapping: should be a dict or a function describing how a batch of training data fetched in the current step is mapped; if it is
@@ -105,10 +106,10 @@ class Trainer(TrainerEventTrigger):
    if the batch is a `dataclass`, we first convert the dataclass to a Dict and then apply the mapping above;
:param model_wo_auto_param_call: whether to disable calling our auto_param_call during training to automatically match batch contents to the parameters of the forward function;
    if this value is False and the batch is a dict, the objects required by forward are extracted from the batch and passed to forward; if this value
    is True, the batch is passed through to the model unchanged. Note this parameter applies to `train_step`, `validate_step` and `test_step`;
    is True, the batch is passed through to the model unchanged. Note this parameter applies to `train_step`, `evaluate_step` and `test_step`;
:param accumulation_steps: number of gradient-accumulation steps, i.e. the optimizer steps once every this many batches; default 1;
:param fp16: whether to enable mixed-precision training; default False;
:param monitor: the default monitor metric name when validate_dataloaders is given. Callbacks that take a monitor argument but did not
:param monitor: the default monitor metric name when evaluate_dataloaders is given. Callbacks that take a monitor argument but did not
    set it at initialization will use this value. If no exact match is found in the evaluation results, the longest-common-substring algorithm picks the best-matching
    name as the monitor. A callable may also be passed; it receives the evaluation results (a dict) and returns a float as the monitor value.
:param larger_better: whether a larger monitor value is better.
@@ -136,10 +137,15 @@ class Trainer(TrainerEventTrigger):
else:
self.driver_name = driver.__class__.__name__
self.device = device
if train_dataloader is None:
raise ValueError("Parameter `train_dataloader` can not be None.")
self.train_dataloader = train_dataloader
self.evaluate_dataloaders = evaluate_dataloaders
self.optimizers = optimizers
self.fp16 = fp16
self.input_mapping = input_mapping
self.output_mapping = output_mapping
self.evaluate_fn = evaluate_fn

self.batch_step_fn = batch_step_fn
if batch_step_fn is not None:
@@ -168,13 +174,12 @@ class Trainer(TrainerEventTrigger):
optimizers=optimizers,
device=device,
n_epochs=n_epochs,
validate_dataloaders=validate_dataloaders,
batch_step_fn=batch_step_fn,
validate_batch_step_fn=validate_batch_step_fn,
validate_mode=validate_mode,
evaluate_batch_step_fn=evaluate_batch_step_fn,
evaluate_fn=evaluate_fn,
callbacks=callbacks,
metrics=metrics,
validate_every=validate_every,
validate_every=evaluate_every,
input_mapping=input_mapping,
output_mapping=output_mapping,
model_wo_auto_param_call=model_wo_auto_param_call,
@@ -185,9 +190,6 @@ class Trainer(TrainerEventTrigger):
)
self.driver.set_optimizers(optimizers=optimizers)

if train_dataloader is not None:
self.driver.set_dataloader(train_dataloader=train_dataloader)

# initialize the callback manager;
self.callback_manager = CallbackManager(callbacks, kwargs.get('progress_bar', 'auto'))
# add all function-style callbacks;
@@ -213,25 +215,25 @@ class Trainer(TrainerEventTrigger):
_dist_sampler = None

""" 设置内部的 Evaluator """
if metrics is None and validate_dataloaders is not None:
if metrics is None and evaluate_dataloaders is not None:
raise ValueError("You have set 'validate_dataloader' but forget to set 'metrics'.")

if metrics is not None and validate_dataloaders is None:
if metrics is not None and evaluate_dataloaders is None:
raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.")

self.evaluator = None
self.monitor = monitor
self.larger_better = larger_better
if metrics is not None and validate_dataloaders is not None:
check_validate_every(validate_every)
if metrics is not None and evaluate_dataloaders is not None:
check_validate_every(evaluate_every)
self.evaluator = Evaluator(
model=model,
dataloaders=validate_dataloaders,
dataloaders=evaluate_dataloaders,
metrics=metrics,
driver=self.driver,
device=device,
batch_step_fn=validate_batch_step_fn,
mode=validate_mode,
batch_step_fn=evaluate_batch_step_fn,
evaluate_fn=evaluate_fn,
input_mapping=input_mapping,
output_mapping=output_mapping,
fp16=fp16,
@@ -241,12 +243,16 @@ class Trainer(TrainerEventTrigger):
)

self.metrics = metrics
self.validate_every = validate_every
self.validate_every = evaluate_every

assert self.driver.has_train_dataloader()
self.driver.setup()
self.driver.barrier()

if train_fn is not None and not isinstance(train_fn, str):
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")
self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn)
self.train_fn = train_fn

self.dataloader = self.train_dataloader
self.driver.set_deterministic_dataloader(self.dataloader)

@@ -273,6 +279,7 @@ class Trainer(TrainerEventTrigger):
By default, a non-distributed driver will catch it, while a distributed driver will not (it cannot catch it).
:return:
"""

if catch_KeyboardInterrupt is None:
catch_KeyboardInterrupt = not self.driver.is_distributed()
else:
@@ -301,7 +308,7 @@ class Trainer(TrainerEventTrigger):

self.num_batches_per_epoch = len(self.dataloader)
self.total_batches = self.num_batches_per_epoch * self.n_epochs
self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch
self.on_train_begin()
self.driver.barrier()
self.driver.zero_grad(self.set_grad_to_none)
@@ -343,7 +350,8 @@ class Trainer(TrainerEventTrigger):
_validate_res: dict = validate_fn()
trainer.on_validate_end(_validate_res)

self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl))
if self.evaluator is not None:
self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl))

def step_validate(self):
"""
@@ -489,11 +497,6 @@ class Trainer(TrainerEventTrigger):
self.has_checked_train_batch_loop = True

""" Trainer 需要的一些 property """

@property
def train_dataloader(self):
return self.driver.train_dataloader

@property
def driver(self):
return self._driver
@@ -632,6 +635,8 @@ class Trainer(TrainerEventTrigger):
:param folder: the file location where the checkpoint-resume states were saved;
:param resume_training: whether to resume training from the last batch or only from the most recent epoch; note that if resume_training=True, we
    only load the states of the model and the optimizers, and the values of the remaining objects are simply reset from the user's Trainer initialization;
:param only_state_dict: whether the saved model contains only the weights.
:param model_load_fn: the model-loading function to use; it should take a folder as its argument and return nothing.
"""
self.driver.barrier()
if isinstance(folder, str):
@@ -670,8 +675,6 @@ class Trainer(TrainerEventTrigger):
# The principle here is that 'number of batches still to come' + 'batch_idx_in_epoch' should equal 'the total number of batches of an uninterrupted run'. Since
# 'number of batches still to come' is determined by how many samples are left, the equality can only be made to hold by adjusting 'batch_idx_in_epoch'
self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch')
self.trainer_state.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + \
self.batch_idx_in_epoch
# This prevents the user from saving again after Trainer.load before the current epoch has finished
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch

@@ -684,7 +687,7 @@ class Trainer(TrainerEventTrigger):

def train_step(self, batch):
with self.driver.auto_cast():
outputs = self.driver.train_step(batch)
outputs = self.driver.model_call(batch, self._train_step, self._train_step_signature_fn)
outputs = match_and_substitute_params(self.output_mapping, outputs)
return outputs

@@ -814,6 +817,24 @@ class Trainer(TrainerEventTrigger):
def data_device(self):
return self.driver.data_device

""" dataloader property """

@property
def train_dataloader(self):
return self._train_dataloader

@train_dataloader.setter
def train_dataloader(self, train_dataloader):
self._train_dataloader = train_dataloader

@property
def evaluate_dataloaders(self):
return self._evaluate_dataloaders

@evaluate_dataloaders.setter
def evaluate_dataloaders(self, evaluate_dataloaders):
self._evaluate_dataloaders = evaluate_dataloaders






+3 -3 fastNLP/core/controllers/utils/state.py

@@ -65,10 +65,10 @@ class TrainerState:
"""
n_epochs: Optional[int] = None # recomputed in any case

cur_epoch_idx: Optional[int] = None # checkpoint resume; 0 only when resume=False;
global_forward_batches: Optional[int] = None # checkpoint resume
cur_epoch_idx: Optional[int] = 0 # checkpoint resume; 0 only when resume=False;
global_forward_batches: Optional[int] = 0 # checkpoint resume

batch_idx_in_epoch: Optional[int] = None # checkpoint resume
batch_idx_in_epoch: Optional[int] = 0 # checkpoint resume

num_batches_per_epoch: Optional[int] = None # recomputed in any case
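With the defaults above changed from None to 0, a freshly constructed TrainerState can be incremented without a separate initialization step; a small illustrative check (module path inferred from the file location):

    from fastNLP.core.controllers.utils.state import TrainerState  # inferred path

    state = TrainerState()
    assert state.cur_epoch_idx == 0
    assert state.global_forward_batches == 0
    assert state.batch_idx_in_epoch == 0
    state.global_forward_batches += 1   # safe without a None check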



+1 -1 fastNLP/core/controllers/utils/utils.py

@@ -128,6 +128,6 @@ class _TruncatedDataLoader:

def check_validate_every(validate_every):
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0):
raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.")
raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.")
if callable(validate_every):
_check_valid_parameters_number(validate_every, expected_params=['trainer'])

+29 -100 fastNLP/core/drivers/driver.py

@@ -1,7 +1,7 @@
import os
import signal
import sys
from typing import Any, Sequence, List, Optional, Callable, Dict, Union
from typing import Any, Sequence, List, Optional, Callable, Dict, Union, Tuple
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
@@ -79,41 +79,44 @@ class Driver(ABC):
"""

@abstractmethod
def train_step(self, batch):
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
"""
Implements the training forward pass by calling the model's own `train_step` or `forward` method;
if the user's model is detected to implement train_step
Implements the training forward pass by calling `fn`;
note that the Trainer and the Evaluator call this function to run the network forward, and the argument `fn` passed in is the function
returned by `get_model_call_fn`;

:param batch: the data of the current batch; may be a dict or any other type;
:return: the result returned by the model's `train_step` or `forward` method (should be a dict or dataclass, but we do not need to check that);
:param fn: the function called to perform one computation.
:param signature_fn: the signature function passed in by the Trainer for one forward pass of the network: when the batch is a Dict we automatically call auto_param_call,
    and some wrapped models need to expose their real signature, e.g. DistributedDataParallel is called through forward but its signature should be model.module.forward;
:return: the result returned by `fn` (should be a dict or dataclass, but we do not need to check that);
"""
raise NotImplementedError("Each specific driver should implemented its own `train_step` function.")
raise NotImplementedError("Each specific driver should implemented its own `model_call` function.")

def validate_step(self, batch):
"""
Implements the evaluation forward pass by calling the model's own `validate_step` or `forward` method;

:param batch: the data of the current batch; may be a dict or any other type;
:return: the result returned by the model's `validate_step` or `forward` method (should be a dict or dataclass, but we do not need to check that);
@abstractmethod
def get_model_call_fn(self, fn: str) -> Tuple:
"""
raise NotImplementedError("Each specific driver should implemented its own `validate_step` function.")
This function receives the Trainer's train_fn or the Evaluator's evaluate_fn and returns the function argument actually passed when calling driver.model_call;
it is called by the Trainer and the Evaluator after the driver.setup function;

def test_step(self, batch):
"""
Implements the evaluation forward pass by calling the model's own `test_step` or `forward` method;
The purpose of this function is to pull the concrete model_call function out of the driver and attach it to the Trainer or the Evaluator instead;
this is because in the new design, which method of the model is used for the `train step` or `evaluate step` is determined by the extra parameters `train_fn` and
`evaluate_fn`, which are controlled by the Trainer and the Evaluator respectively; the logic that decides the concrete `train step fn` and
`evaluate step fn` therefore cannot live in each driver's initialization (when the Trainer initializes the driver, the Evaluator does not exist yet, even though
determining the `evaluate step fn` requires the Evaluator to be initialized), so we abstract this logic into this function;

:param batch: the data of the current batch; may be a dict or any other type;
:return: the result returned by the model's `test_step` or `forward` method (should be a dict or dataclass, but we do not need to check that);
"""
raise NotImplementedError("Each specific driver should implemented its own `test_step` function.")
This function should decide the actual function to return based on the parameter `fn`; the concrete logic is as follows:
1. If fn == "train_step" or "evaluate_step", the given model is inspected; if the model does not define a method `fn`, the model's `forward`
   function is called instead and a warning is emitted;
2. If fn is any other string and the model does not define a method `fn`, an error is raised directly;
Note that different drivers need extra checks: for example in the DDPDriver, when the model passed in is already a DistributedDataParallel, we can only call the model's
forward function, so an extra warning is needed; a particular subtlety is that the driver itself may also change the model during setup (DDPDriver), so it
may be necessary to additionally record which form the model originally passed to the driver had;

def check_evaluator_mode(self, mode: str):
r"""
Because the logic of validate_step and test_step in the concrete drivers is: if the model does not implement the function in question, check whether it implements the other one;
therefore if the user's evaluator mode is validate but the model passed in implements test_step rather than validate_step,
we should remind the user of this behaviour;
:param fn: should be a string; this function decides which method of the model to return based on that string;
:return: a tuple of two functions, to be passed in when calling driver.model_call;
"""
raise NotImplementedError("Each specific driver should implemented its own `check_evaluator_mode` function.")
raise NotImplementedError("Each specific driver should implemented its own `get_model_call_fn` function.")

@property
def model(self):
@@ -123,80 +126,6 @@ class Driver(ABC):
def model(self, model):
self._model = model

@property
def train_dataloader(self):
return self._train_dataloader

@train_dataloader.setter
def train_dataloader(self, train_dataloader: Any):
self._train_dataloader = train_dataloader

@property
def validate_dataloaders(self):
return self._validate_dataloaders

@validate_dataloaders.setter
def validate_dataloaders(self, validate_dataloaders: Any):
self._validate_dataloaders = validate_dataloaders

@property
def test_dataloaders(self):
return self._test_dataloaders

@test_dataloaders.setter
def test_dataloaders(self, test_dataloaders: Any):
self._test_dataloaders = test_dataloaders

@property
def predict_dataloaders(self):
return self._predict_dataloaders

@predict_dataloaders.setter
def predict_dataloaders(self, predict_dataloaders: Any):
self._predict_dataloaders = predict_dataloaders

def set_dataloader(self, **kwargs):
r"""
Sets the data used during training or evaluation; used by the trainer and the evaluator to attach the dataloaders to each concrete driver;

:param kwargs: the input data, which should be set through 'keyword-only' arguments;
"""
if "train_dataloader" in kwargs:
self.train_dataloader = kwargs["train_dataloader"]
self._check_dataloader_legality(self.train_dataloader, "train_dataloader", is_train=True)
if "validate_dataloaders" in kwargs:
self.validate_dataloaders = kwargs["validate_dataloaders"]
self._check_dataloader_legality(self.validate_dataloaders, "validate_dataloaders", is_train=False)
if "test_dataloaders" in kwargs:
self.test_dataloaders = kwargs["test_dataloaders"]
self._check_dataloader_legality(self.test_dataloaders, "test_dataloaders", is_train=False)
if "predict_dataloaders" in kwargs:
self.predict_dataloaders = kwargs["predict_dataloaders"]
self._check_dataloader_legality(self.predict_dataloaders, "predict_dataloaders", is_train=False)

@staticmethod
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
r"""
This function checks the legality of the dataloader after the trainer or evaluator has set it, because different deep-learning frameworks
expect different dataloader behaviour;

:param dataloader: the input `dataloader` to check;
:param dataloader_name:
"""
raise NotImplementedError("Each specific driver should implemented its own `_check_dataloader_legality` function.")

def has_train_dataloader(self):
return "_train_dataloader" in self.__dict__

def has_validate_dataloaders(self):
return "_validate_dataloaders" in self.__dict__

def has_test_dataloaders(self):
return "_test_dataloaders" in self.__dict__

def has_predict_dataloaders(self):
return "_predict_dataloaders" in self.__dict__

@property
def optimizers(self) -> List:
r"""


+7 -7 fastNLP/core/drivers/jittor_driver/jittor_driver.py

@@ -39,7 +39,7 @@ class JittorDriver(Driver):
self.grad_scaler = _grad_scaler()

@staticmethod
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
# JittorDataLoader is implemented inside fastNLP
# TODO: should passing a Dataset be allowed?
if is_train:
@@ -64,18 +64,18 @@ class JittorDriver(Driver):
def check_evaluator_mode(self, mode: str):
model = self.unwrap_model()
if mode == "validate":
if not hasattr(model, "validate_step"):
if not hasattr(model, "evaluate_step"):
if hasattr(model, "test_step"):
logger.warning_once(
"Your model does not have 'validate_step' method but has 'test_step' method, but you"
"are using 'mode=validate', we are going to use 'test_step' to substitute for"
"'validate_step'.")
"Your model does not have 'evaluate_step' method but has 'test_step' method, but you"
"are using 'evaluate_fn=validate', we are going to use 'test_step' to substitute for"
"'evaluate_step'.")

else:
if not hasattr(model, "test_step"):
if hasattr(model, "validate_step"):
if hasattr(model, "evaluate_step"):
logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you"
"are using 'mode=test', we are going to use 'validate_step' to substitute for"
"are using 'evaluate_fn=test', we are going to use 'evaluate_step' to substitute for"
"'test_step'.")

def save_model(self, filepath: str, only_state_dict: bool = False, model_save_fn: Optional[Callable]=None):


+5 -5 fastNLP/core/drivers/jittor_driver/single_device.py

@@ -35,8 +35,8 @@ class JittorSingleDriver(JittorDriver):
model = self.unwrap_model()
self._train_signature_fn = model.execute

if hasattr(self.model, "validate_step"):
self._validate_step = self.model.validate_step
if hasattr(self.model, "evaluate_step"):
self._validate_step = self.model.evaluate_step
self._validate_signature_fn = None
elif hasattr(self.model, "test_step"):
self._validate_step = self.model.test_step
@@ -49,9 +49,9 @@ class JittorSingleDriver(JittorDriver):
if hasattr(self.model, "test_step"):
self._test_step = self.model.test_step
self._test_signature_fn = None
elif hasattr(self.model, "validate_step"):
self._test_step = self.model.validate_step
self._test_signature_fn = self.model.validate_step
elif hasattr(self.model, "evaluate_step"):
self._test_step = self.model.evaluate_step
self._test_signature_fn = self.model.evaluate_step
else:
self._test_step = self.model
model = self.unwrap_model()


+3 -3 fastNLP/core/drivers/paddle_driver/fleet.py

@@ -118,11 +118,11 @@ class PaddleFleetDriver(PaddleDriver):
" call `forward` function instead of `train_step` and you should note that.")
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)

if hasattr(model, "validate_step"):
if hasattr(model, "evaluate_step"):
logger.warning(
"Notice your model is a `paddle.DataParallel` model. And your "
"model also implements the `validate_step` method, which we can not call actually, "
"we will call `forward` function instead of `validate_step` and you should note that.")
"model also implements the `evaluate_step` method, which we can not call actually, "
"we will call `forward` function instead of `evaluate_step` and you should note that.")
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)

if hasattr(model, "test_step"):


+8 -8 fastNLP/core/drivers/paddle_driver/paddle_driver.py

@@ -72,7 +72,7 @@ class PaddleDriver(Driver):
optimizer.clear_grad()

@staticmethod
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
r"""
This function checks the legality of the dataloader after the trainer or evaluator has set it.
The dataloader passed in must be a `paddle.io.DataLoader` or a dict containing that type.
@@ -117,24 +117,24 @@ class PaddleDriver(Driver):

def check_evaluator_mode(self, mode: str):
r"""
Because the logic of validate_step and test_step in the concrete drivers is: if the model does not implement the function in question, check whether it implements the other one;
therefore if the user's evaluator mode is validate but the model passed in implements test_step rather than validate_step,
Because the logic of evaluate_step and test_step in the concrete drivers is: if the model does not implement the function in question, check whether it implements the other one;
therefore if the user's evaluator evaluate_fn is validate but the model passed in implements test_step rather than evaluate_step,
we should remind the user of this behaviour;
"""
model = self.unwrap_model()
if mode == "validate":
if not hasattr(model, "validate_step"):
if not hasattr(model, "evaluate_step"):
if hasattr(model, "test_step"):
logger.warning(
"Your model does not have 'validate_step' method but has 'test_step' method, but you"
"Your model does not have 'evaluate_step' method but has 'test_step' method, but you"
"are using 'Evaluator.validate', we are going to use 'test_step' to substitute for"
"'validate_step'.")
"'evaluate_step'.")

else:
if not hasattr(model, "test_step"):
if hasattr(model, "validate_step"):
if hasattr(model, "evaluate_step"):
logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you"
"are using 'Evaluator.test', we are going to use 'validate_step' to substitute for"
"are using 'Evaluator.test', we are going to use 'evaluate_step' to substitute for"
"'test_step'.")

@staticmethod


+8 -8 fastNLP/core/drivers/paddle_driver/single_device.py

@@ -50,10 +50,10 @@ class PaddleSingleDriver(PaddleDriver):
self._train_step = self.model
self._train_signature_fn = model.forward

if hasattr(model, "validate_step"):
if hasattr(model, "evaluate_step"):
logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also "
"implements the `validate_step` method, which we can not call actually, we "
"will call `forward` function instead of `validate_step` and you should note that.")
"implements the `evaluate_step` method, which we can not call actually, we "
"will call `forward` function instead of `evaluate_step` and you should note that.")
self._validate_step = self.model
self._validate_signature_fn = model.forward

@@ -73,8 +73,8 @@ class PaddleSingleDriver(PaddleDriver):
model = self.unwrap_model()
self._train_signature_fn = model.forward

if hasattr(self.model, "validate_step"):
self._validate_step = self.model.validate_step
if hasattr(self.model, "evaluate_step"):
self._validate_step = self.model.evaluate_step
self._validate_signature_fn = None
elif hasattr(self.model, "test_step"):
self._validate_step = self.model.test_step
@@ -87,9 +87,9 @@ class PaddleSingleDriver(PaddleDriver):
if hasattr(self.model, "test_step"):
self._test_step = self.model.test_step
self._test_signature_fn = None
elif hasattr(self.model, "validate_step"):
self._test_step = self.model.validate_step
self._test_signature_fn = self.model.validate_step
elif hasattr(self.model, "evaluate_step"):
self._test_step = self.model.evaluate_step
self._test_signature_fn = self.model.evaluate_step
else:
self._test_step = self.model
model = self.unwrap_model()


+7 -7 fastNLP/core/drivers/paddle_driver/utils.py

@@ -108,11 +108,11 @@ class _FleetWrappingModel(Layer):
self._train_step = self.model
self._train_signature_fn = model.forward

if hasattr(model, "validate_step"):
if hasattr(model, "evaluate_step"):
logger.warning(
"Notice your model is a `paddle.DataParallel` model. And your "
"model also implements the `validate_step` method, which we can not call actually, "
"we will call `forward` function instead of `validate_step` and you should note that.")
"model also implements the `evaluate_step` method, which we can not call actually, "
"we will call `forward` function instead of `evaluate_step` and you should note that.")
self._validate_step = self.model
self._validate_signature_fn = model.forward

@@ -131,7 +131,7 @@ class _FleetWrappingModel(Layer):
self._train_step = model
self._train_signature_fn = model.forward

if hasattr(model, "validate_step"):
if hasattr(model, "evaluate_step"):
self._validate_step = model.validate_step
self._validate_signature_fn = None
elif hasattr(model, "test_step"):
@@ -144,7 +144,7 @@ class _FleetWrappingModel(Layer):
if hasattr(model, "test_step"):
self._test_step = model.test_step
self._test_signature_fn = None
elif hasattr(model, "validate_step"):
elif hasattr(model, "evaluate_step"):
self._test_step = model.validate_step
self._test_signature_fn = None
else:
@@ -172,9 +172,9 @@ class _FleetWrappingModel(Layer):
else:
return self._test_step(batch)
elif forward_state == ForwardState.PREDICT:
raise NotImplementedError("'PREDICT' mode has not been implemented.")
raise NotImplementedError("'PREDICT' evaluate_fn has not been implemented.")
else:
raise NotImplementedError("You should direct a concrete mode.")
raise NotImplementedError("You should direct a concrete evaluate_fn.")

class DummyGradScaler:
"""


+40 -61 fastNLP/core/drivers/torch_driver/ddp.py

@@ -4,7 +4,7 @@ import __main__
import socket
import numpy as np
from time import sleep
from typing import List, Optional, Union, Dict
from typing import List, Optional, Union, Dict, Tuple, Callable
from functools import partial

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
@@ -21,8 +21,6 @@ __all__ = [
from .torch_driver import TorchDriver
from fastNLP.core.drivers.torch_driver.utils import (
_DDPWrappingModel,
ForwardState,
_MODE_PARAMETER,
reset_seed,
replace_sampler,
replace_batch_sampler
@@ -158,10 +156,10 @@ class TorchDDPDriver(TorchDriver):
————————————————————————————————————————————————————————————————————————————————————————————————————————

3. The role of _DDPWrappingModel;
Because we both need to call the model's `train_step`, `validate_step` and `test_step` methods and need the forward function of `DistributedDataParallel`
Because we both need to call the model's `train_step`, `evaluate_step` and `test_step` methods and need the forward function of `DistributedDataParallel`
to synchronize the gradients across devices, we first wrap the model in an extra layer; at forward time the call first goes through the forward method of `DistributedDataParallel`
and then through the forward method of `_DDPWrappingModel`, in which we decide whether to call the model's own
forward function or its `train_step`, `validate_step`, `test_step` methods.
forward function or its `train_step`, `evaluate_step`, `test_step` methods.

4. How `TorchDDPDriver` handles an exception raised in one of the processes;

@@ -204,37 +202,6 @@ class TorchDDPDriver(TorchDriver):
# we simply set model_device to None;
self.model_device = None

def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call):
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(step_fn, batch, signature_fn=signature_fn)
else:
return step_fn(batch)

model = model.module
if hasattr(model, "train_step"):
logger.warning(
"Notice your model is a `DistributedDataParallel` model. And your "
"model also implements the `train_step` method, which we can not call actually, we will"
" call `forward` function instead of `train_step` and you should note that.")
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
# self._train_signature_fn = model.forward

if hasattr(model, "validate_step"):
logger.warning(
"Notice your model is a `DistributedDataParallel` model. And your "
"model also implements the `validate_step` method, which we can not call actually, "
"we will call `forward` function instead of `validate_step` and you should note that.")
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
# self._validate_signature_fn = model.forward

if hasattr(model, "test_step"):
logger.warning(
"Notice your model is a `DistributedDataParallel` model. And your "
"model also implements the `test_step` method, which we can not call actually, we will"
" call `forward` function instead of `test_step` and you should note that.")
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
# self._test_signature_fn = model.forward

# When the user initializes DDP themselves outside, we set model_device to None; the user can then move the corresponding data to the desired machine via `data_device`;
self._data_device = kwargs.get("data_device", None)
if isinstance(self._data_device, int):
@@ -253,7 +220,6 @@ class TorchDDPDriver(TorchDriver):
# world_size is the total number of GPUs globally;
self.world_size = None # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device)
self.global_rank = 0
self._configured = False # used to prevent configure_ddp() from being called repeatedly

self._ddp_kwargs = kwargs.get("torch_ddp_kwargs", {})
check_user_specific_params(self._ddp_kwargs, DistributedDataParallel.__init__)
@@ -268,8 +234,8 @@ class TorchDDPDriver(TorchDriver):
os.makedirs(name=self.output_from_new_proc, exist_ok=True)
self.output_from_new_proc = os.path.abspath(self.output_from_new_proc)

# This flag exists because the evaluator would also run setup, which is clearly neither needed nor desirable;
self._has_setup = False
self._has_setup = False # This flag exists because the evaluator would also run setup, which is clearly neither needed nor desirable;
self._has_ddpwrapped = False # records whether the model passed in has been wrapped by _DDPWrappingModel;

def setup(self):
if self._has_setup:
@@ -341,24 +307,16 @@ class TorchDDPDriver(TorchDriver):
self._pids = self.tensor_to_numeric(self._pids)

def configure_ddp(self):
if not self._configured and not isinstance(self.model, DistributedDataParallel):
if not isinstance(self.model, DistributedDataParallel):
self.model = DistributedDataParallel(
# note that self.model_device here is of `torch.device` type, hence self.model_device.index;
_DDPWrappingModel(self.model), device_ids=[self.model_device.index],
**self._ddp_kwargs
)

self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call)
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call)
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call)

self._configured = True
self._has_ddpwrapped = True

def open_subprocess(self):
if self.local_rank == 0:
# self._consensus_file = Path(tempfile.mkstemp()[1])
# self._consensus_file.unlink()

# Script called as `python a/b/c.py`
if __main__.__spec__ is None: # pragma: no-cover
# pull out the commands used to run the script and resolve the abs file path
@@ -432,18 +390,39 @@ class TorchDDPDriver(TorchDriver):
return self._data_device
return self.model_device

def train_step(self, batch):
# note that self.model here is already a 'fastNLP.drivers.utils._DDPWrappingModel';
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.TRAIN})
return self._train_step(batch)

def validate_step(self, batch):
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.VALIDATE})
return self._validate_step(batch)

def test_step(self, batch):
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST})
return self._test_step(batch)
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
if self._has_ddpwrapped:
return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn,
wo_auto_param_call=self.wo_auto_param_call)
else:
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(fn, batch, signature_fn=signature_fn)
else:
return fn(batch)

def get_model_call_fn(self, fn: str) -> Tuple:
model = self.unwrap_model()
if self._has_ddpwrapped:
if hasattr(model, fn):
fn = getattr(model, fn)
if not callable(fn):
raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.")
return fn, None
elif fn in {"train_step", "evaluate_step"}:
return model, model.forward
else:
raise RuntimeError(f"There is no `{fn}` method in your model.")
else:
if hasattr(model, fn):
logger.warning("Notice your model is a `DistributedDataParallel` model. And your model also implements "
f"the `{fn}` method, which we can not call actually, we will"
" call `forward` function instead of `train_step` and you should note that.")
elif fn not in {"train_step", "evaluate_step"}:
raise RuntimeError(f"There is no `{fn}` method in your model. And also notice that your model is a "
"`DistributedDataParallel` model, which means that we will only call model.forward "
"function when we are in forward propagation.")

return self.model, model.forward

def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]]=None,
reproducible: bool = False):


+26 -70 fastNLP/core/drivers/torch_driver/single_device.py

@@ -1,5 +1,5 @@
import os
from typing import Dict, Union
from typing import Dict, Union, Callable, Tuple, Optional
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
@@ -42,84 +42,40 @@ class TorchSingleDriver(TorchDriver):
self.global_rank = 0
self.world_size = 1

if isinstance(model, DataParallel):
model = self.unwrap_model()
if hasattr(model, "train_step"):
logger.warning("Notice your model is a `DataParallel` or `DistributedDataParallel` model. And your "
"model also implements the `train_step` method, which we can not call actually, we will"
" call `forward` function instead of `train_step` and you should note that.")
self._train_step = self.model
self._train_signature_fn = model.forward

if hasattr(model, "validate_step"):
logger.warning("Notice your model is a `DataParallel` or `DistributedDataParallel` model. And your "
"model also implements the `validate_step` method, which we can not call actually, "
"we will call `forward` function instead of `validate_step` and you should note that.")
self._validate_step = self.model
self._validate_signature_fn = model.forward

if hasattr(model, "test_step"):
logger.warning("Notice your model is a `DataParallel` or `DistributedDataParallel` model. And your "
"model also implements the `test_step` method, which we can not call actually, we will"
" call `forward` function instead of `test_step` and you should note that.")
self._test_step = self.model
self._test_signature_fn = model.forward
else:
if hasattr(self.model, "train_step"):
self._train_step = self.model.train_step
self._train_signature_fn = None
else:
self._train_step = self.model
# the input model is a `DataParallel` or `DistributedDataParallel`; we need to make sure its signature_fn is correct;
model = self.unwrap_model()
self._train_signature_fn = model.forward

if hasattr(self.model, "validate_step"):
self._validate_step = self.model.validate_step
self._validate_signature_fn = None
elif hasattr(self.model, "test_step"):
self._validate_step = self.model.test_step
self._validate_signature_fn = self.model.test_step
else:
self._validate_step = self.model
model = self.unwrap_model()
self._validate_signature_fn = model.forward

if hasattr(self.model, "test_step"):
self._test_step = self.model.test_step
self._test_signature_fn = None
elif hasattr(self.model, "validate_step"):
self._test_step = self.model.validate_step
self._test_signature_fn = self.model.validate_step
else:
self._test_step = self.model
model = self.unwrap_model()
self._test_signature_fn = model.forward

def setup(self):
if self.model_device is not None:
self.model.to(self.model_device)

def train_step(self, batch) -> Dict:
# If batch is a Dict, we do parameter matching for it by default; otherwise it is passed directly to the `train_step` function for the user to handle;
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
return auto_param_call(fn, batch, signature_fn=signature_fn)
else:
return self._train_step(batch)
return fn(batch)

def validate_step(self, batch) -> Dict:
# Because the Tester's logic is to hand all metrics to the tester, which controls each metric's update and compute, whether or not the user
# implements a validate_step function it should return a dict; which fields are actually used is decided by each individual metric inside validate_batch_loop;
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
else:
return self._validate_step(batch)
def get_model_call_fn(self, fn: str) -> Tuple:
if isinstance(self.model, DataParallel):
model = self.unwrap_model()
if hasattr(model, fn):
logger.warning("Notice your model is a `DataParallel` model. And your model also implements the "
f"`{fn}` method, which we can not call actually, we will"
" call `forward` function instead of `train_step` and you should note that.")

def test_step(self, batch) -> Dict:
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
elif fn not in {"train_step", "evaluate_step"}:
raise RuntimeError(f"There is no `{fn}` method in your model. And also notice that your model is a "
f"`DataParallel` model, which means that we will only call model.forward function "
f"when we are in forward propagation.")

return self.model, model.forward
else:
return self._test_step(batch)
if hasattr(self.model, fn):
fn = getattr(self.model, fn)
if not callable(fn):
raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
return fn, None
elif fn in {"train_step", "evaluate_step"}:
return self.model, self.model.forward
else:
raise RuntimeError(f"There is no `{fn}` method in your model.")

def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None,
reproducible: bool = False):

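A rough sketch of how the refactored pair is meant to be used by the train/evaluate loops (`driver` and `batch` here are hypothetical stand-ins for what the controllers hold):

# resolve the callable and its signature helper once, up front
fn, signature_fn = driver.get_model_call_fn("train_step")
# then, once per batch:
outputs = driver.model_call(batch, fn, signature_fn)
# a dict batch goes through auto_param_call(fn, batch, signature_fn=signature_fn);
# any other batch type is simply passed through as fn(batch)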

+ 2
- 18
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -81,7 +81,7 @@ class TorchDriver(Driver):
self.grad_scaler.update()

@staticmethod
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
if is_train:
if not isinstance(dataloader, DataLoader):
raise ValueError(f"Parameter `{dataloader_name}` should be 'DataLoader' type, not {type(dataloader)}.")
@@ -108,23 +108,6 @@ class TorchDriver(Driver):
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, "
f"not {type(each_optimizer)}.")

def check_evaluator_mode(self, mode: str):
model = self.unwrap_model()
if mode == "validate":
if not hasattr(model, "validate_step"):
if hasattr(model, "test_step"):
logger.warning_once(
"Your model does not have 'validate_step' method but has 'test_step' method, but you"
"are using 'mode=validate', we are going to use 'test_step' to substitute for"
"'validate_step'.")

else:
if not hasattr(model, "test_step"):
if hasattr(model, "validate_step"):
logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you"
"are using 'mode=test', we are going to use 'validate_step' to substitute for"
"'test_step'.")

@staticmethod
def tensor_to_numeric(tensor, reduce=None):
if tensor is None:
@@ -216,6 +199,7 @@ class TorchDriver(Driver):
num_consumed_batches = sampler_states['num_consumed_samples']
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
states['sampler_states'] = sampler_states
else:
raise RuntimeError(
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')

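With the leading underscore dropped, `check_dataloader_legality` is now part of the driver's public surface (the paddle tests below call it the same way). A minimal sketch against a torch DataLoader, with the import path assumed from the file location shown above:

import torch
from torch.utils.data import DataLoader, TensorDataset
from fastNLP.core.drivers.torch_driver.torch_driver import TorchDriver  # import path assumed

loader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2)
# raises ValueError for anything that is not a DataLoader when is_train=True
TorchDriver.check_dataloader_legality(loader, "train_dataloader", is_train=True)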

+ 6
- 51
fastNLP/core/drivers/torch_driver/utils.py View File

@@ -90,14 +90,11 @@ class ForwardState(IntEnum):
PREDICT = 3


_MODE_PARAMETER = "_forward_state"


class _DDPWrappingModel(Module):
"""
This class handles the user's customized train_step and similar functions during DDP training;
The reason this extra wrapping model is needed is that, when using DDP, DistributedDataParallel's forward function must be used for things to run correctly;
On the other hand, we require users of our framework to implement different handlers for different modes, e.g. 'train_step', 'validate_step', etc.;
On the other hand, we require users of our framework to implement different handlers for different modes, e.g. 'train_step', 'evaluate_step', etc.;
However, once the model is wrapped in DistributedDataParallel, none of its methods other than forward are visible; and if we try to actively extract
`model = model.module` during training, that likewise causes errors and makes the model parameters differ across GPUs;

@@ -109,60 +106,18 @@ class _DDPWrappingModel(Module):
super(_DDPWrappingModel, self).__init__()
self.model = model

if hasattr(model, "train_step"):
self._train_step = model.train_step
self._train_signature_fn = None
else:
self._train_step = model
self._train_signature_fn = model.forward

if hasattr(model, "validate_step"):
self._validate_step = model.validate_step
self._validate_signature_fn = None
elif hasattr(model, "test_step"):
self._validate_step = model.test_step
self._validate_signature_fn = None
else:
self._validate_step = model
self._validate_signature_fn = model.forward

if hasattr(model, "test_step"):
self._test_step = model.test_step
self._test_signature_fn = None
elif hasattr(model, "validate_step"):
self._test_step = model.validate_step
self._test_signature_fn = None
else:
self._test_step = model
self._test_signature_fn = model.forward

def forward(self, batch, **kwargs) -> Dict:
"""
pytorch lightning unwraps the model first, but that does not seem necessary for us; leaving this note here so we can revisit it if the need arises;
"""
forward_state = kwargs.pop(_MODE_PARAMETER)
fn = kwargs.pop("fastnlp_fn")
signature_fn = kwargs.pop("fastnlp_signature_fn")
wo_auto_param_call = kwargs.pop("wo_auto_param_call")

if forward_state == ForwardState.TRAIN:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
else:
return self._train_step(batch)
elif forward_state == ForwardState.VALIDATE:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
else:
return self._validate_step(batch)
elif forward_state == ForwardState.TEST:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
else:
return self._test_step(batch)
elif forward_state == ForwardState.PREDICT:
raise NotImplementedError("'PREDICT' mode has not been implemented.")
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(fn, batch, signature_fn=signature_fn)
else:
raise NotImplementedError("You should direct a concrete mode.")
return fn(batch)


class DummyGradScaler:

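After this refactor the DDP driver is expected to funnel every call through DistributedDataParallel's own forward and hand the resolved callable in via kwargs, matching the pops above. A sketch of that call shape (`ddp_model` and `user_model` are hypothetical names):

outputs = ddp_model(
    batch,
    fastnlp_fn=user_model.train_step,   # whatever get_model_call_fn resolved
    fastnlp_signature_fn=None,
    wo_auto_param_call=False,
)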

+ 5
- 5
fastNLP/core/drivers/torch_paddle_driver/torch_paddle_driver.py View File

@@ -55,8 +55,8 @@ class TorchPaddleDriver(Driver):
self._train_step = self.model
self._train_signature_fn = self.model.forward

if hasattr(self.model, "validate_step"):
self._validate_step = self.model.validate_step
if hasattr(self.model, "evaluate_step"):
self._validate_step = self.model.evaluate_step
self._validate_signature_fn = None
elif hasattr(self.model, "test_step"):
self._validate_step = self.model.test_step
@@ -68,8 +68,8 @@ class TorchPaddleDriver(Driver):
if hasattr(self.model, "test_step"):
self._test_step = self.model.test_step
self._test_signature_fn = None
elif hasattr(self.model, "validate_step"):
self._test_step = self.model.validate_step
elif hasattr(self.model, "evaluate_step"):
self._test_step = self.model.evaluate_step
self._test_signature_fn = self.model.forward
else:
self._test_step = self.model
@@ -81,7 +81,7 @@ class TorchPaddleDriver(Driver):
self.model.to(self.model_device)

@staticmethod
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
if is_train:
if not isinstance(dataloader, (TorchDataLoader, PaddleDataLoader)):
raise ValueError(f"Parameter `{dataloader_name}` should be 'torch.util.data.DataLoader' or `paddle.io.dataloader` type, not {type(dataloader)}.")


+ 3
- 3
fastNLP/core/log/logger.py View File

@@ -211,9 +211,9 @@ def _add_file_handler(_logger: logging.Logger, path: Optional[Union[str, Path]]
raise TypeError("Parameter `remove_other_handlers` can only be `bool` type.")

if not isinstance(mode, str):
raise TypeError("Parameter 'mode' can only be `str` type.")
raise TypeError("Parameter 'evaluate_fn' can only be `str` type.")
if mode not in {"w", "a"}:
raise ValueError("Parameter `mode` can only be one of these values: ('w', 'a').")
raise ValueError("Parameter `evaluate_fn` can only be one of these values: ('w', 'a').")

for h in _logger.handlers:
if isinstance(h, logging.FileHandler):
@@ -230,7 +230,7 @@ def _add_file_handler(_logger: logging.Logger, path: Optional[Union[str, Path]]
dirname = os.path.abspath(os.path.dirname(path))
os.makedirs(dirname, exist_ok=True)

# Whenever distributed training is detected here, we switch mode to "a"; one issue this causes is that, if the second run is also distributed, the log the logger records will not
# Whenever distributed training is detected here, we switch evaluate_fn to "a"; one issue this causes is that, if the second run is also distributed, the log the logger records will not
# overwrite the original file but will instead keep appending to the previous log;
# This is mainly done to fix the following situation: in distributed training, process 1 reaches this point before process 0, and then process 0 overwrites process 1's log;
if is_cur_env_distributed():# and int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) != 0:

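The comment above implies that the branch which follows (its body is cut off by the hunk) simply forces append mode for every rank; a sketch of the intent, not the literal code:

if is_cur_env_distributed():
    mode = "a"  # append, so a rank that starts later cannot truncate a log another rank is already writing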

+ 8
- 8
tests/core/callbacks/test_checkpoint_callback_torch.py View File

@@ -124,7 +124,7 @@ def test_model_checkpoint_callback_1(
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
validate_dataloaders=model_and_optimizers.validate_dataloaders,
evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
@@ -204,7 +204,7 @@ def test_model_checkpoint_callback_1(
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
validate_dataloaders=model_and_optimizers.validate_dataloaders,
evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
@@ -264,7 +264,7 @@ def test_model_checkpoint_callback_2(
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
validate_dataloaders=model_and_optimizers.validate_dataloaders,
evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
@@ -302,7 +302,7 @@ def test_model_checkpoint_callback_2(
device=4,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
validate_dataloaders=model_and_optimizers.validate_dataloaders,
evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
@@ -370,7 +370,7 @@ def test_trainer_checkpoint_callback_1(
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
validate_dataloaders=model_and_optimizers.validate_dataloaders,
evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
@@ -448,7 +448,7 @@ def test_trainer_checkpoint_callback_1(
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
validate_dataloaders=model_and_optimizers.validate_dataloaders,
evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
@@ -626,7 +626,7 @@ def test_trainer_checkpoint_callback_2(
train_dataloader=test_bert_dataloader_train,
optimizers=test_bert_optimizers,

validate_dataloaders=test_bert_dataloader_validate,
evaluate_dataloaders=test_bert_dataloader_validate,
input_mapping=bert_input_mapping,
output_mapping=bert_output_mapping,
metrics={"acc": acc},
@@ -700,7 +700,7 @@ def test_trainer_checkpoint_callback_2(
train_dataloader=test_bert_dataloader_train,
optimizers=test_bert_optimizers,

validate_dataloaders=test_bert_dataloader_validate,
evaluate_dataloaders=test_bert_dataloader_validate,
input_mapping=bert_input_mapping,
output_mapping=bert_output_mapping,
metrics={"acc": acc},


+ 1
- 1
tests/core/callbacks/test_load_best_model_callback_torch.py View File

@@ -92,7 +92,7 @@ def test_load_best_model_callback(
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
validate_dataloaders=model_and_optimizers.validate_dataloaders,
evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=lambda output: output if ('loss' in output) else {'pred':output['preds'], 'target': output['target']},
metrics=model_and_optimizers.metrics,


+ 1
- 1
tests/core/controllers/_test_distributed_launch_torch_1.py View File

@@ -89,7 +89,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps(
device=None,
optimizers=optimizers,
train_dataloader=train_dataloader,
validate_dataloaders=validate_dataloaders,
evaluate_dataloaders=validate_dataloaders,
metrics=metrics,

n_epochs=2,


+ 1
- 1
tests/core/controllers/_test_distributed_launch_torch_2.py View File

@@ -77,7 +77,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps(
device=None,
optimizers=optimizers,
train_dataloader=train_dataloader,
validate_dataloaders=validate_dataloaders,
evaluate_dataloaders=validate_dataloaders,
metrics=metrics,

n_epochs=2,


+ 1
- 1
tests/core/controllers/test_trainer_event_trigger.py View File

@@ -82,7 +82,7 @@ def test_trainer_event_trigger(
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
validate_dataloaders=model_and_optimizers.validate_dataloaders,
evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,


+ 2
- 2
tests/core/controllers/test_trainer_fleet.py View File

@@ -64,8 +64,8 @@ def test_trainer_fleet(
device=device,
optimizers=optimizers,
train_dataloader=train_dataloader,
validate_dataloaders=validate_dataloaders,
validate_every=validate_every,
evaluate_dataloaders=validate_dataloaders,
evaluate_every=validate_every,
input_mapping=None,
output_mapping=None,
metrics=metrics,


+ 2
- 2
tests/core/controllers/test_trainer_fleet_outside.py View File

@@ -70,8 +70,8 @@ def test_trainer_fleet(
device=device,
optimizers=optimizers,
train_dataloader=train_dataloader,
validate_dataloaders=validate_dataloaders,
validate_every=validate_every,
evaluate_dataloaders=validate_dataloaders,
evaluate_every=validate_every,
input_mapping=None,
output_mapping=None,
metrics=metrics,


+ 1
- 1
tests/core/controllers/test_trainer_paddle.py View File

@@ -96,4 +96,4 @@ def test_trainer_paddle(
n_epochs=n_epochs,
callbacks=callbacks,
)
trainer.run()
trainer.run()

+ 8
- 8
tests/core/controllers/test_trainer_w_evaluator_torch.py View File

@@ -98,16 +98,16 @@ def model_and_optimizers(request):


# Test the ordinary case;
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]])
@pytest.mark.parametrize("validate_every", [-3])
@pytest.mark.parametrize("evaluate_every", [-3, -1, 100])
@magic_argv_env_context
def test_trainer_torch_with_evaluator(
model_and_optimizers: TrainerParameters,
driver,
device,
callbacks,
validate_every,
evaluate_every,
n_epochs=10,
):
trainer = Trainer(
@@ -116,11 +116,11 @@ def test_trainer_torch_with_evaluator(
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
validate_dataloaders=model_and_optimizers.validate_dataloaders,
evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
validate_every=validate_every,
evaluate_every=evaluate_every,

n_epochs=n_epochs,
callbacks=callbacks,
@@ -152,7 +152,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps(
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
validate_dataloaders=model_and_optimizers.validate_dataloaders,
evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
@@ -193,14 +193,14 @@ def test_trainer_validate_every(
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
validate_dataloaders=model_and_optimizers.validate_dataloaders,
evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,

n_epochs=n_epochs,
output_from_new_proc="all",
validate_every=validate_every
evaluate_every=validate_every
)

trainer.run()

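Putting the renamed arguments together, a Trainer call in these tests now reads roughly as follows (the model, optimizer, dataloaders, and the metric object are placeholders):

trainer = Trainer(
    model=model,
    driver="torch",
    device="cpu",
    optimizers=optimizer,
    train_dataloader=train_dataloader,
    evaluate_dataloaders=val_dataloader,  # formerly validate_dataloaders
    evaluate_every=-1,                    # formerly validate_every
    metrics={"acc": acc},
    n_epochs=2,
)
trainer.run()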

+ 5
- 5
tests/core/controllers/test_trainer_wo_evaluator_torch.py View File

@@ -91,7 +91,7 @@ def test_trainer_torch_without_evaluator(
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
validate_dataloaders=model_and_optimizers.validate_dataloaders,
evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
@@ -126,7 +126,7 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps(
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
validate_dataloaders=model_and_optimizers.validate_dataloaders,
evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
@@ -163,7 +163,7 @@ def test_trainer_torch_without_evaluator_accumulation_steps(

optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
validate_dataloaders=model_and_optimizers.validate_dataloaders,
evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
@@ -202,7 +202,7 @@ def test_trainer_output_from_new_proc(

optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
validate_dataloaders=model_and_optimizers.validate_dataloaders,
evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
@@ -267,7 +267,7 @@ def test_trainer_on_exception(

optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
validate_dataloaders=model_and_optimizers.validate_dataloaders,
evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,


+ 7
- 7
tests/core/drivers/paddle_driver/test_single_device.py View File

@@ -423,12 +423,12 @@ class TestPaddleDriverFunctions:
Test the behaviour of the _check_dataloader_legality function when is_train is True
"""
dataloader = paddle.io.DataLoader(PaddleNormalDataset())
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)

# The case where both batch_size and batch_sampler are None
dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None)
with pytest.raises(ValueError):
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)

# Create a torch dataloader
dataloader = torch.utils.data.DataLoader(
@@ -436,7 +436,7 @@ class TestPaddleDriverFunctions:
batch_size=32, shuffle=True
)
with pytest.raises(ValueError):
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)

def test_check_dataloader_legality_in_test(self):
"""
@@ -447,7 +447,7 @@ class TestPaddleDriverFunctions:
"train": paddle.io.DataLoader(PaddleNormalDataset()),
"test":paddle.io.DataLoader(PaddleNormalDataset())
}
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)

# The case where both batch_size and batch_sampler are None
dataloader = {
@@ -455,12 +455,12 @@ class TestPaddleDriverFunctions:
"test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None)
}
with pytest.raises(ValueError):
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)

# Passing in something that is not a dict should raise an error
dataloader = paddle.io.DataLoader(PaddleNormalDataset())
with pytest.raises(ValueError):
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)

# Create a torch dataloader
train_loader = torch.utils.data.DataLoader(
@@ -473,7 +473,7 @@ class TestPaddleDriverFunctions:
)
dataloader = {"train": train_loader, "test": test_loader}
with pytest.raises(ValueError):
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)

def test_tensor_to_numeric(self):
"""


+ 1
- 0
tests/core/utils/test_utils.py View File

@@ -181,6 +181,7 @@ class TestCheckNumberOfParameters:


def test_get_fun_msg():
# Test that it runs
def demo(x):
pass


+ 15
- 1
tests/helpers/datasets/normal_data.py View File

@@ -1,3 +1,6 @@
import numpy as np


class NormalIterator:
def __init__(self, num_of_data=1000):
self._num_of_data = num_of_data
@@ -15,4 +18,15 @@ class NormalIterator:
return self._data

def __len__(self):
return self._num_of_data
return self._num_of_data


class RandomDataset:
def __init__(self, num_data=10):
self.data = np.random.rand(num_data)

def __len__(self):
return len(self.data)

def __getitem__(self, item):
return self.data[item]

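The new RandomDataset is a plain map-style dataset, so it can be fed straight into a torch DataLoader; a quick illustrative sketch, not part of the test suite:

from torch.utils.data import DataLoader
from tests.helpers.datasets.normal_data import RandomDataset

loader = DataLoader(RandomDataset(num_data=10), batch_size=2)
for batch in loader:
    print(batch.shape)  # each batch is a 1-D float tensor of length 2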
+ 1
- 1
tests/helpers/models/torch_model.py View File

@@ -28,7 +28,7 @@ class TorchNormalModel_Classification_1(nn.Module):
x = self(x)
return {"loss": self.loss_fn(x, y)}

def validate_step(self, x, y):
def evaluate_step(self, x, y):
"""
If the parameter y is not included, output_mapping = {"y": "target"} should be set in the trainer;
"""

