@@ -390,4 +390,23 @@ class HasMonitorCallback(Callback): | |||||
if (self.larger_better and monitor_value1 > monitor_value2) or \ | if (self.larger_better and monitor_value1 > monitor_value2) or \ | ||||
(not self.larger_better and monitor_value1 < monitor_value2): | (not self.larger_better and monitor_value1 < monitor_value2): | ||||
better = True | better = True | ||||
return better | |||||
return better | |||||
@property | |||||
def monitor_name(self): | |||||
""" | |||||
返回 monitor 的名字,如果 monitor 是个 callable 的函数,则返回该函数的名称。 | |||||
:return: | |||||
""" | |||||
if callable(self.monitor): | |||||
try: | |||||
monitor_name = self.monitor.__qualname__ | |||||
except: | |||||
monitor_name = self.monitor.__name__ | |||||
elif self.monitor is None: | |||||
return None | |||||
else: | |||||
# 这里是能是monitor,而不能是real_monitor,因为用户再次运行的时候real_monitor被初始化为monitor了 | |||||
monitor_name = str(self.monitor) | |||||
return monitor_name |
@@ -19,11 +19,11 @@ from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | |||||
class CheckpointCallback(HasMonitorCallback): | class CheckpointCallback(HasMonitorCallback): | ||||
def __init__( | def __init__( | ||||
self, | self, | ||||
monitor, | |||||
monitor:Optional[Union[str, Callable]]=None, | |||||
save_folder: Optional[Union[str, Path]] = None, | save_folder: Optional[Union[str, Path]] = None, | ||||
save_every_n_epochs: Optional[int] = None, | save_every_n_epochs: Optional[int] = None, | ||||
save_every_n_batches: Optional[int] = None, | save_every_n_batches: Optional[int] = None, | ||||
save_last: bool = True, | |||||
save_last: bool = False, | |||||
save_topk: Optional[int] = None, | save_topk: Optional[int] = None, | ||||
save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None, | save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None, | ||||
larger_better: bool = True, | larger_better: bool = True, | ||||
@@ -31,12 +31,32 @@ class CheckpointCallback(HasMonitorCallback): | |||||
model_save_fn: Optional[Callable] = None, | model_save_fn: Optional[Callable] = None, | ||||
**kwargs, | **kwargs, | ||||
): | ): | ||||
""" | |||||
请使用 ModelCheckpointCallback 与 TrainerCheckpointCallback 。 | |||||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | |||||
:param save_every_n_batches: 多少个 batch 保存一次。 | |||||
:param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||||
:param save_topk: 保存 monitor 结果 topK 个。 | |||||
:param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。 | |||||
:param larger_better: monitor 的值是否时越大越好。 | |||||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | |||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
:param kwargs: | |||||
""" | |||||
super().__init__(monitor=monitor, larger_better=larger_better, | super().__init__(monitor=monitor, larger_better=larger_better, | ||||
must_have_monitor=save_topk is not None) | must_have_monitor=save_topk is not None) | ||||
if save_folder is None: | if save_folder is None: | ||||
logger.warning( | logger.warning( | ||||
"Parameter `path` is None, and we will use the current work directory to find and load your model.") | "Parameter `path` is None, and we will use the current work directory to find and load your model.") | ||||
save_folder = Path.cwd() | save_folder = Path.cwd() | ||||
save_folder = Path(save_folder) | |||||
if not save_folder.exists(): | if not save_folder.exists(): | ||||
raise NotADirectoryError(f"Path '{save_folder.absolute()}' is not existed!") | raise NotADirectoryError(f"Path '{save_folder.absolute()}' is not existed!") | ||||
elif save_folder.is_file(): | elif save_folder.is_file(): | ||||
@@ -71,7 +91,7 @@ class CheckpointCallback(HasMonitorCallback): | |||||
else: | else: | ||||
save_on_exception = [] | save_on_exception = [] | ||||
self.save_folder = Path(save_folder) | |||||
self.save_folder = save_folder | |||||
self.save_every_n_epochs = save_every_n_epochs | self.save_every_n_epochs = save_every_n_epochs | ||||
self.save_every_n_batches = save_every_n_batches | self.save_every_n_batches = save_every_n_batches | ||||
self.save_last = save_last | self.save_last = save_last | ||||
@@ -88,18 +108,15 @@ class CheckpointCallback(HasMonitorCallback): | |||||
# 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, | # 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, | ||||
# FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; | # FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; | ||||
self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | ||||
# 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行; | |||||
synchronize_mkdir(self.timestamp_path) | |||||
# 该 folder 只在保存真的要发生的时候再创建。 | |||||
def on_after_trainer_initialized(self, trainer, driver): | def on_after_trainer_initialized(self, trainer, driver): | ||||
if self.save_topk is not None: | if self.save_topk is not None: | ||||
super().on_after_trainer_initialized(trainer, driver) | super().on_after_trainer_initialized(trainer, driver) | ||||
if self.save_topk is not None and trainer.evaluator is None: | if self.save_topk is not None and trainer.evaluator is None: | ||||
logger.warning("You set `save_topk`, but `validate_dataloaders` is not set in Trainer.") | |||||
logger.warning("You set `save_topk`, but `evaluate_dataloaders` is not set in Trainer.") | |||||
def on_validate_end(self, trainer, results): | def on_validate_end(self, trainer, results): | ||||
if len(results) == 0: | |||||
return | |||||
self._save_topk(trainer, results) | self._save_topk(trainer, results) | ||||
def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): | def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): | ||||
@@ -136,16 +153,17 @@ class CheckpointCallback(HasMonitorCallback): | |||||
states['timestamp_path'] = str(self.timestamp_path.absolute()) | states['timestamp_path'] = str(self.timestamp_path.absolute()) | ||||
states['_topk_model'] = deepcopy(self._topk_model) | states['_topk_model'] = deepcopy(self._topk_model) | ||||
states['save_topk'] = 0 if self.save_topk is None else self.save_topk | states['save_topk'] = 0 if self.save_topk is None else self.save_topk | ||||
states['_real_monitor'] = self._real_monitor | |||||
if isinstance(self._real_monitor, str): | |||||
states['_real_monitor'] = self._real_monitor | |||||
return states | return states | ||||
def on_load_checkpoint(self, trainer, states: Optional[Dict]): | def on_load_checkpoint(self, trainer, states: Optional[Dict]): | ||||
timestamp_path = states['timestamp_path'] | timestamp_path = states['timestamp_path'] | ||||
if not os.path.exists(timestamp_path): | if not os.path.exists(timestamp_path): | ||||
logger.info(f"The resuming save folder {timestamp_path} is not exists, will checkpoint save to " | |||||
logger.info(f"The resuming checkpoint folder {timestamp_path} is not exists, will checkpoint save to " | |||||
f" {self.timestamp_path.absolute()}.") | f" {self.timestamp_path.absolute()}.") | ||||
else: | else: | ||||
logger.info(f"Resume to save in path: {timestamp_path}.") | |||||
logger.info(f"Resume to checkpoint in path: {timestamp_path}.") | |||||
self.timestamp_path = Path(timestamp_path) | self.timestamp_path = Path(timestamp_path) | ||||
_topk_model = states['_topk_model'] | _topk_model = states['_topk_model'] | ||||
save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk']) | save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk']) | ||||
@@ -153,7 +171,8 @@ class CheckpointCallback(HasMonitorCallback): | |||||
assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \ | assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \ | ||||
f"as {save_topk}." | f"as {save_topk}." | ||||
self._topk_model.update(self._topk_model) | self._topk_model.update(self._topk_model) | ||||
self._real_monitor = states["real_monitor"] | |||||
self._real_monitor = states["_real_monitor"] | |||||
def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict): | def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict): | ||||
""" | """ | ||||
@@ -231,9 +250,9 @@ class ModelCheckpointCallback(CheckpointCallback): | |||||
model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 | model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 | ||||
若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 | 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 | ||||
:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型), | |||||
返回一个 float 值作为 monitor 的结果。 | |||||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | ||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | ||||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | :param save_every_n_epochs: 多少个 epoch 保存一次。 | ||||
@@ -249,6 +268,11 @@ class ModelCheckpointCallback(CheckpointCallback): | |||||
""" | """ | ||||
@property | @property | ||||
def save_fn_name(self): | def save_fn_name(self): | ||||
""" | |||||
调用 Trainer 中的哪个函数。 | |||||
:return: | |||||
""" | |||||
return 'save_model' | return 'save_model' | ||||
@property | @property | ||||
@@ -257,7 +281,7 @@ class ModelCheckpointCallback(CheckpointCallback): | |||||
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | ||||
:return: | :return: | ||||
""" | """ | ||||
return f"model_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||||
return f"model_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||||
@property | @property | ||||
def folder_prefix(self): | def folder_prefix(self): | ||||
@@ -279,9 +303,9 @@ class TrainerCheckpointCallback(CheckpointCallback): | |||||
model_save_fn 为 None ,则以上每个 folder 中,将生成两个文件:fastnlp_trainer.pkl.tar 以及 fastnlp_model.pkl.tar 。 | model_save_fn 为 None ,则以上每个 folder 中,将生成两个文件:fastnlp_trainer.pkl.tar 以及 fastnlp_model.pkl.tar 。 | ||||
若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 | 若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 | ||||
:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型), | |||||
返回一个 float 值作为 monitor 的结果。 | |||||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | ||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | ||||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | :param save_every_n_epochs: 多少个 epoch 保存一次。 | ||||
@@ -297,6 +321,11 @@ class TrainerCheckpointCallback(CheckpointCallback): | |||||
""" | """ | ||||
@property | @property | ||||
def save_fn_name(self): | def save_fn_name(self): | ||||
""" | |||||
调用 Trainer 中的哪个函数。 | |||||
:return: | |||||
""" | |||||
return 'save' | return 'save' | ||||
@property | @property | ||||
@@ -305,7 +334,8 @@ class TrainerCheckpointCallback(CheckpointCallback): | |||||
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | ||||
:return: | :return: | ||||
""" | """ | ||||
return f"trainer_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||||
return f"trainer_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||||
@property | @property | ||||
def folder_prefix(self): | def folder_prefix(self): | ||||
@@ -12,8 +12,9 @@ class EarlyStopCallback(HasMonitorCallback): | |||||
def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10): | def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10): | ||||
""" | """ | ||||
:param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 | |||||
evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param larger_better: monitor 的值是否是越大越好。 | :param larger_better: monitor 的值是否是越大越好。 | ||||
:param patience: 多少次 validate 不没有提升就停止。 | :param patience: 多少次 validate 不没有提升就停止。 | ||||
""" | """ | ||||
@@ -46,17 +47,20 @@ class EarlyStopCallback(HasMonitorCallback): | |||||
states = { | states = { | ||||
'patience': self.patience, | 'patience': self.patience, | ||||
'wait': self.wait, | 'wait': self.wait, | ||||
'monitor': self.monitor, | |||||
'monitor_value': self.monitor_value | 'monitor_value': self.monitor_value | ||||
} | } | ||||
if not callable(self._real_monitor): | |||||
states['_real_monitor'] = self._real_monitor | |||||
return states | return states | ||||
def on_load_checkpoint(self, trainer, states): | def on_load_checkpoint(self, trainer, states): | ||||
self.patience = states['patience'] | self.patience = states['patience'] | ||||
self.wait = states['wait'] | self.wait = states['wait'] | ||||
self.monitor = states['monitor'] | |||||
self.monitor_value = float(states['monitor_value']) | self.monitor_value = float(states['monitor_value']) | ||||
if '_real_monitor' in states: | |||||
self._real_monitor = states['_real_monitor'] | |||||
@property | |||||
def callback_name(self): | def callback_name(self): | ||||
return f'EarlyStopCallback#monitor-{self.monitor}#patience-{self.patience}' | |||||
return f'EarlyStopCallback#monitor-{self.monitor_name}#patience-{self.patience}' | |||||
@@ -21,8 +21,9 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
""" | """ | ||||
保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 | 保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 | ||||
:param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 | |||||
evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param larger_better: 该 metric 值是否是越大越好。 | :param larger_better: 该 metric 值是否是越大越好。 | ||||
:param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | ||||
不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | 不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | ||||
@@ -44,10 +44,11 @@ class RichCallback(ProgressCallback): | |||||
:param print_every: 多少个 batch 更新一次显示。 | :param print_every: 多少个 batch 更新一次显示。 | ||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。如果为 None ,会尝试使用 trainer 中设置的 monitor 。 | |||||
也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param larger_better: 是否是monitor的结果越大越好。 | |||||
:param format_json: 是否format json再打印 | |||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | |||||
完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | |||||
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param larger_better: 是否是 monitor 的结果越大越好。 | |||||
:param format_json: 是否格式化 json 再打印 | |||||
""" | """ | ||||
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) | super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) | ||||
self.print_every = print_every | self.print_every = print_every | ||||
@@ -136,8 +137,9 @@ class RawTextCallback(ProgressCallback): | |||||
:param print_every: 多少个 batch 更新一次显示。 | :param print_every: 多少个 batch 更新一次显示。 | ||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。也可以传入一个函数,接受参数为 evaluation 的结果( | |||||
字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | |||||
完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | |||||
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param larger_better: 是否是monitor的结果越大越好。 | :param larger_better: 是否是monitor的结果越大越好。 | ||||
:param format_json: 是否format json再打印 | :param format_json: 是否format json再打印 | ||||
""" | """ | ||||
@@ -36,10 +36,10 @@ class Evaluator: | |||||
model, | model, | ||||
dataloaders, | dataloaders, | ||||
metrics: Optional[Union[Dict, Metric]] = None, | metrics: Optional[Union[Dict, Metric]] = None, | ||||
driver: Union[str, Driver] = 'single', | |||||
driver: Union[str, Driver] = 'torch', | |||||
device: Optional[Union[int, List[int], str]] = None, | device: Optional[Union[int, List[int], str]] = None, | ||||
batch_step_fn: Optional[callable] = None, | batch_step_fn: Optional[callable] = None, | ||||
mode: Optional[Union[str, callable]] = 'validate', # 首先尝试找 evaluate_step, 找不到 forward, callable | |||||
evaluate_fn: Optional[str] = None, # 首先尝试找 evaluate_step, 找不到 forward, callable | |||||
input_mapping: Optional[Union[Callable, Dict]] = None, | input_mapping: Optional[Union[Callable, Dict]] = None, | ||||
output_mapping: Optional[Union[Callable, Dict]] = None, | output_mapping: Optional[Union[Callable, Dict]] = None, | ||||
model_wo_auto_param_call: bool = False, | model_wo_auto_param_call: bool = False, | ||||
@@ -49,8 +49,8 @@ class Evaluator: | |||||
): | ): | ||||
""" | """ | ||||
:param dataloaders: | |||||
:param model: | |||||
:param model: 待测试的模型,如果传入的 driver 为 Driver 实例,该参数将被忽略。 | |||||
:param dataloaders: 待评测的数据集。 | |||||
:param metrics: 使用的 metric 。必须为 dict 类型,其中 key 为 metric 的名称,value 为一个 Metric 对象。支持 fastNLP 的 | :param metrics: 使用的 metric 。必须为 dict 类型,其中 key 为 metric 的名称,value 为一个 Metric 对象。支持 fastNLP 的 | ||||
metric ,torchmetrics,allennlpmetrics等。 | metric ,torchmetrics,allennlpmetrics等。 | ||||
:param driver: 使用 driver 。 | :param driver: 使用 driver 。 | ||||
@@ -58,14 +58,13 @@ class Evaluator: | |||||
:param batch_step_fn: callable的对象,接受 (evaluator, batch) 作为参数,其中 evaluator 为 Evaluator 对象,batch 为 | :param batch_step_fn: callable的对象,接受 (evaluator, batch) 作为参数,其中 evaluator 为 Evaluator 对象,batch 为 | ||||
DataLoader 中返回的对象。一个 batch_step_fn 的例子可参考 fastNLP.core.controller.loops.evaluate_batch_loop 的 | DataLoader 中返回的对象。一个 batch_step_fn 的例子可参考 fastNLP.core.controller.loops.evaluate_batch_loop 的 | ||||
batch_step_fn 函数。 | batch_step_fn 函数。 | ||||
:param mode: 可选 ["validate", "test"], 当为 "validate" 时将首先尝试寻找 model 是否有 validate_step 函数,没有的话则尝试 | |||||
寻找 test_step 函数,都没找到则使用 model 的前向运算函数。当为 "test" 是将首先尝试寻找 model 是否有 test_step 函数, | |||||
没有的话尝试 "validate_step" 函数,都没找到则使用 model 的前向运算函数。 | |||||
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 `model.forward`; | |||||
默认为 None,如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有找到该方法,则使用 `model.forward` 函数; | |||||
:param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中 | :param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中 | ||||
:param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 | :param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 | ||||
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | ||||
如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | 如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | ||||
为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`; | |||||
为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `evaluate_step` 和 `test_step`; | |||||
:param fp16: 是否使用 fp16 。 | :param fp16: 是否使用 fp16 。 | ||||
:param verbose: 是否打印 evaluate 的结果。 | :param verbose: 是否打印 evaluate 的结果。 | ||||
:param kwargs: | :param kwargs: | ||||
@@ -87,9 +86,11 @@ class Evaluator: | |||||
self.model = model | self.model = model | ||||
self.metrics = metrics | self.metrics = metrics | ||||
self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, **kwargs) | self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, **kwargs) | ||||
if dataloaders is None: | |||||
raise ValueError("Parameter `dataloaders` can not be None.") | |||||
self.dataloaders = dataloaders | |||||
self.device = device | self.device = device | ||||
self.verbose = verbose | self.verbose = verbose | ||||
@@ -97,21 +98,12 @@ class Evaluator: | |||||
_check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn') | _check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn') | ||||
self.batch_step_fn = batch_step_fn | self.batch_step_fn = batch_step_fn | ||||
self.mode = mode | |||||
assert mode in {'validate', 'test'}, "Parameter `mode` should only be 'validate' or 'test'." | |||||
self.input_mapping = input_mapping | self.input_mapping = input_mapping | ||||
self.output_mapping = output_mapping | self.output_mapping = output_mapping | ||||
if not isinstance(dataloaders, dict): | if not isinstance(dataloaders, dict): | ||||
dataloaders = {None: dataloaders} | dataloaders = {None: dataloaders} | ||||
if mode == "validate": | |||||
self._evaluate_step = self.driver.validate_step | |||||
self.driver.set_dataloader(validate_dataloaders=dataloaders) | |||||
else: | |||||
self._evaluate_step = self.driver.test_step | |||||
self.driver.set_dataloader(test_dataloaders=dataloaders) | |||||
self.mode = mode | |||||
self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=batch_step_fn) | self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=batch_step_fn) | ||||
self.separator = kwargs.get('separator', '#') | self.separator = kwargs.get('separator', '#') | ||||
self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True) | self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True) | ||||
@@ -123,10 +115,15 @@ class Evaluator: | |||||
self._metric_wrapper = None | self._metric_wrapper = None | ||||
_ = self.metrics_wrapper # 触发检查 | _ = self.metrics_wrapper # 触发检查 | ||||
assert self.driver.has_validate_dataloaders() or self.driver.has_test_dataloaders() | |||||
self.driver.setup() | self.driver.setup() | ||||
self.driver.barrier() | self.driver.barrier() | ||||
if evaluate_fn is not None and not isinstance(evaluate_fn, str): | |||||
raise TypeError("Parameter `evaluate_fn` can only be `str` type when it is not None.") | |||||
self._evaluate_step, self._evaluate_step_signature_fn = \ | |||||
self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn) | |||||
self.evaluate_fn = evaluate_fn | |||||
self.dataloaders = {} | self.dataloaders = {} | ||||
for name, dl in dataloaders.items(): # 替换为正确的 sampler | for name, dl in dataloaders.items(): # 替换为正确的 sampler | ||||
dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist=self._dist_sampler, reproducible=False) | dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist=self._dist_sampler, reproducible=False) | ||||
@@ -136,7 +133,6 @@ class Evaluator: | |||||
if self.progress_bar == 'auto': | if self.progress_bar == 'auto': | ||||
self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw' | self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw' | ||||
self.driver.check_evaluator_mode(self.mode) | |||||
self.driver.barrier() | self.driver.barrier() | ||||
def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: | def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: | ||||
@@ -156,11 +152,6 @@ class Evaluator: | |||||
assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." | assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." | ||||
assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0." | assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0." | ||||
if self.mode == 'validate': | |||||
assert self.driver.has_validate_dataloaders() | |||||
else: | |||||
assert self.driver.has_test_dataloaders() | |||||
metric_results = {} | metric_results = {} | ||||
self.reset() | self.reset() | ||||
evaluate_context = self.driver.get_evaluate_context() | evaluate_context = self.driver.get_evaluate_context() | ||||
@@ -235,13 +226,6 @@ class Evaluator: | |||||
f_rich_progress.destroy_task(self._rich_task_id) | f_rich_progress.destroy_task(self._rich_task_id) | ||||
delattr(self, '_rich_task_id') | delattr(self, '_rich_task_id') | ||||
@property | |||||
def eval_dataloaders(self): | |||||
if self.mode == "validate": | |||||
return self.driver.validate_dataloaders | |||||
else: | |||||
return self.driver.test_dataloaders | |||||
@property | @property | ||||
def evaluate_batch_loop(self): | def evaluate_batch_loop(self): | ||||
return self._evaluate_batch_loop | return self._evaluate_batch_loop | ||||
@@ -296,13 +280,13 @@ class Evaluator: | |||||
def evaluate_step(self, batch): | def evaluate_step(self, batch): | ||||
""" | """ | ||||
将 batch 传递到model中进行处理,根据当前 mode 选择进行 evaluate 还是 test 。会将返回结果经过 output_mapping 处理后再 | |||||
将 batch 传递到model中进行处理,根据当前 evaluate_fn 选择进行 evaluate 还是 test 。会将返回结果经过 output_mapping 处理后再 | |||||
返回。 | 返回。 | ||||
:param batch: | :param batch: | ||||
:return: | :return: | ||||
""" | """ | ||||
outputs = self._evaluate_step(batch) | |||||
outputs = self.driver.model_call(batch, self._evaluate_step, self._evaluate_step_signature_fn) | |||||
outputs = match_and_substitute_params(self.output_mapping, outputs) | outputs = match_and_substitute_params(self.output_mapping, outputs) | ||||
return outputs | return outputs | ||||
@@ -20,7 +20,7 @@ class TrainBatchLoop(Loop): | |||||
else lambda *args, **kwargs: None | else lambda *args, **kwargs: None | ||||
dataloader = iter(dataloader) | dataloader = iter(dataloader) | ||||
indices = None | indices = None | ||||
while True: | |||||
while trainer.batch_idx_in_epoch<=trainer.num_batches_per_epoch: | |||||
try: | try: | ||||
trainer.on_fetch_data_begin() | trainer.on_fetch_data_begin() | ||||
batch = next(dataloader) | batch = next(dataloader) | ||||
@@ -30,10 +30,8 @@ class TrainBatchLoop(Loop): | |||||
batch = trainer.move_data_to_device(batch) | batch = trainer.move_data_to_device(batch) | ||||
except StopIteration: | except StopIteration: | ||||
break | break | ||||
except EarlyStopException: # 在 Trainer 处理 earlystop 的 exception | |||||
break | |||||
except BaseException as e: | except BaseException as e: | ||||
if indices: | |||||
if indices and not isinstance(e, EarlyStopException): | |||||
logger.debug(f"The following exception happens when running on samples: {indices}") | logger.debug(f"The following exception happens when running on samples: {indices}") | ||||
raise e | raise e | ||||
@@ -41,19 +41,20 @@ class Trainer(TrainerEventTrigger): | |||||
optimizers, | optimizers, | ||||
device: Optional[Union[int, List[int], str]] = "cpu", | device: Optional[Union[int, List[int], str]] = "cpu", | ||||
n_epochs: int = 20, | n_epochs: int = 20, | ||||
validate_dataloaders=None, | |||||
evaluate_dataloaders=None, | |||||
batch_step_fn: Optional[Callable] = None, | batch_step_fn: Optional[Callable] = None, | ||||
validate_batch_step_fn: Optional[Callable] = None, | |||||
validate_mode: Union[str, callable] = 'validate', | |||||
evaluate_batch_step_fn: Optional[Callable] = None, | |||||
train_fn: Optional[str] = None, | |||||
evaluate_fn: Optional[str] = None, | |||||
callbacks: Union[List[Callback], Callback, None] = None, | callbacks: Union[List[Callback], Callback, None] = None, | ||||
metrics: Optional[dict] = None, | metrics: Optional[dict] = None, | ||||
validate_every: Optional[Union[int, callable]] = -1, | |||||
evaluate_every: Optional[Union[int, Callable]] = -1, | |||||
input_mapping: Optional[Union[Callable, Dict]] = None, | input_mapping: Optional[Union[Callable, Dict]] = None, | ||||
output_mapping: Optional[Union[Callable, Dict]] = None, | output_mapping: Optional[Union[Callable, Dict]] = None, | ||||
model_wo_auto_param_call: bool = False, | model_wo_auto_param_call: bool = False, | ||||
accumulation_steps: int = 1, | accumulation_steps: int = 1, | ||||
fp16: bool = False, | fp16: bool = False, | ||||
monitor: Union[str, callable] = None, | |||||
monitor: Union[str, Callable] = None, | |||||
larger_better: bool = True, | larger_better: bool = True, | ||||
marker: Optional[str] = None, | marker: Optional[str] = None, | ||||
**kwargs | **kwargs | ||||
@@ -79,19 +80,21 @@ class Trainer(TrainerEventTrigger): | |||||
4. list(int):如果多于1个device,应当通过该种方式进行设定;当 `device` 为一个 list 时,我们默认使用 `TorchDDPDriver`; | 4. list(int):如果多于1个device,应当通过该种方式进行设定;当 `device` 为一个 list 时,我们默认使用 `TorchDDPDriver`; | ||||
5. None: 为None则不对模型进行任何处理; | 5. None: 为None则不对模型进行任何处理; | ||||
:param n_epochs: 训练总共的 epoch 的数量,默认为 20; | :param n_epochs: 训练总共的 epoch 的数量,默认为 20; | ||||
:param validate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 | |||||
:param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 | |||||
为 None; | 为 None; | ||||
:param batch_step_fn: 用来替换 `TrainBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的两个参数必须为 `trainer` 和 | :param batch_step_fn: 用来替换 `TrainBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的两个参数必须为 `trainer` 和 | ||||
`batch`;默认为 None; | `batch`;默认为 None; | ||||
:param validate_batch_step_fn: 用来替换 'Evaluator' 中的 `EvaluateBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的 | |||||
:param evaluate_batch_step_fn: 用来替换 'Evaluator' 中的 `EvaluateBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的 | |||||
两个参数必须为 `evaluator` 和 `batch`;默认为 None; | 两个参数必须为 `evaluator` 和 `batch`;默认为 None; | ||||
:param validate_mode: 用来控制 `Trainer` 中内置的 `Evaluator` 的模式,其值应当为以下之一:["validate", "test"]; | |||||
默认为 "validate";当为 "validate" 时将首先尝试寻找 model 是否有 validate_step 函数,没有的话则尝试 | |||||
寻找 test_step 函数,都没找到则使用 model 的前向运算函数。当为 "test" 是将首先尝试寻找 model 是否有 test_step 函数, | |||||
没有的话尝试 "validate_step" 函数,都没找到则使用 model 的前向运算函数。 | |||||
:param train_fn: 用来控制 `Trainer` 在训练的前向传播过程中是调用模型的哪一个函数,例如是 `train_step` 还是 `forward`; | |||||
默认为 None,如果该值是 None,那么我们会默认使用 `train_step` 当做前向传播的函数,如果在模型中没有找到该方法, | |||||
则使用模型默认的前向传播函数。 | |||||
:param evaluate_fn: 用来控制 `Trainer` 中内置的 `Evaluator` 的模式,应当为 None 或者一个字符串;其使用方式和 train_fn 类似; | |||||
注意该参数我们会直接传给 Trainer 中内置的 Evaluator(如果不为 None);如果该值为 None ,将首先尝试寻找模型中是否有 | |||||
evaluate_step 这个函数,如果没有则使用 forward 函数。 | |||||
:param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 `Callback` 类; | :param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 `Callback` 类; | ||||
:param metrics: 应当为一个字典,其中 key 表示 monitor,例如 {"acc1": AccMetric(), "acc2": AccMetric()}; | :param metrics: 应当为一个字典,其中 key 表示 monitor,例如 {"acc1": AccMetric(), "acc2": AccMetric()}; | ||||
:param validate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch validate 一次;为正数则表示每隔几个 batch validate 一次; | |||||
:param evaluate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch validate 一次;为正数则表示每隔几个 batch validate 一次; | |||||
为函数时表示用户自己传入的用于控制 Trainer 中的 validate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 | 为函数时表示用户自己传入的用于控制 Trainer 中的 validate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 | ||||
返回一个 bool 值,返回为 True 说明需要进行 validate ;将在每个 batch 结束后调用该函数判断是否需要 validate 。 | 返回一个 bool 值,返回为 True 说明需要进行 validate ;将在每个 batch 结束后调用该函数判断是否需要 validate 。 | ||||
:param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理;如果其是 | :param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理;如果其是 | ||||
@@ -105,10 +108,10 @@ class Trainer(TrainerEventTrigger): | |||||
如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; | 如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; | ||||
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | ||||
如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | 如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | ||||
为 True,那么我们会将 batch 直接透传给模型。注意该参数应用于 `train_step`, `validate_step` 和 `test_step`; | |||||
为 True,那么我们会将 batch 直接透传给模型。注意该参数应用于 `train_step`, `evaluate_step` 和 `test_step`; | |||||
:param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; | :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; | ||||
:param fp16: 是否开启混合精度训练;默认为 False; | :param fp16: 是否开启混合精度训练;默认为 False; | ||||
:param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 | |||||
:param monitor: 当存在 evaluate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 | |||||
在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | 在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | ||||
的那个作为 monitor 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | 的那个作为 monitor 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | ||||
:param larger_better: monitor 的值是否是越大越好。 | :param larger_better: monitor 的值是否是越大越好。 | ||||
@@ -136,10 +139,15 @@ class Trainer(TrainerEventTrigger): | |||||
else: | else: | ||||
self.driver_name = driver.__class__.__name__ | self.driver_name = driver.__class__.__name__ | ||||
self.device = device | self.device = device | ||||
if train_dataloader is None: | |||||
raise ValueError("Parameter `train_dataloader` can not be None.") | |||||
self.train_dataloader = train_dataloader | |||||
self.evaluate_dataloaders = evaluate_dataloaders | |||||
self.optimizers = optimizers | self.optimizers = optimizers | ||||
self.fp16 = fp16 | self.fp16 = fp16 | ||||
self.input_mapping = input_mapping | self.input_mapping = input_mapping | ||||
self.output_mapping = output_mapping | self.output_mapping = output_mapping | ||||
self.evaluate_fn = evaluate_fn | |||||
self.batch_step_fn = batch_step_fn | self.batch_step_fn = batch_step_fn | ||||
if batch_step_fn is not None: | if batch_step_fn is not None: | ||||
@@ -168,13 +176,13 @@ class Trainer(TrainerEventTrigger): | |||||
optimizers=optimizers, | optimizers=optimizers, | ||||
device=device, | device=device, | ||||
n_epochs=n_epochs, | n_epochs=n_epochs, | ||||
validate_dataloaders=validate_dataloaders, | |||||
evaluate_dataloaders=evaluate_dataloaders, | |||||
batch_step_fn=batch_step_fn, | batch_step_fn=batch_step_fn, | ||||
validate_batch_step_fn=validate_batch_step_fn, | |||||
validate_mode=validate_mode, | |||||
evaluate_batch_step_fn=evaluate_batch_step_fn, | |||||
evaluate_fn=evaluate_fn, | |||||
callbacks=callbacks, | callbacks=callbacks, | ||||
metrics=metrics, | metrics=metrics, | ||||
validate_every=validate_every, | |||||
evaluate_every=evaluate_every, | |||||
input_mapping=input_mapping, | input_mapping=input_mapping, | ||||
output_mapping=output_mapping, | output_mapping=output_mapping, | ||||
model_wo_auto_param_call=model_wo_auto_param_call, | model_wo_auto_param_call=model_wo_auto_param_call, | ||||
@@ -185,9 +193,6 @@ class Trainer(TrainerEventTrigger): | |||||
) | ) | ||||
self.driver.set_optimizers(optimizers=optimizers) | self.driver.set_optimizers(optimizers=optimizers) | ||||
if train_dataloader is not None: | |||||
self.driver.set_dataloader(train_dataloader=train_dataloader) | |||||
# 初始化 callback manager; | # 初始化 callback manager; | ||||
self.callback_manager = CallbackManager(callbacks, kwargs.get('progress_bar', 'auto')) | self.callback_manager = CallbackManager(callbacks, kwargs.get('progress_bar', 'auto')) | ||||
# 添加所有的函数式 callbacks; | # 添加所有的函数式 callbacks; | ||||
@@ -213,25 +218,25 @@ class Trainer(TrainerEventTrigger): | |||||
_dist_sampler = None | _dist_sampler = None | ||||
""" 设置内部的 Evaluator """ | """ 设置内部的 Evaluator """ | ||||
if metrics is None and validate_dataloaders is not None: | |||||
if metrics is None and evaluate_dataloaders is not None: | |||||
raise ValueError("You have set 'validate_dataloader' but forget to set 'metrics'.") | raise ValueError("You have set 'validate_dataloader' but forget to set 'metrics'.") | ||||
if metrics is not None and validate_dataloaders is None: | |||||
if metrics is not None and evaluate_dataloaders is None: | |||||
raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.") | raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.") | ||||
self.evaluator = None | self.evaluator = None | ||||
self.monitor = monitor | self.monitor = monitor | ||||
self.larger_better = larger_better | self.larger_better = larger_better | ||||
if metrics is not None and validate_dataloaders is not None: | |||||
check_validate_every(validate_every) | |||||
if metrics is not None and evaluate_dataloaders is not None: | |||||
check_validate_every(evaluate_every) | |||||
self.evaluator = Evaluator( | self.evaluator = Evaluator( | ||||
model=model, | model=model, | ||||
dataloaders=validate_dataloaders, | |||||
dataloaders=evaluate_dataloaders, | |||||
metrics=metrics, | metrics=metrics, | ||||
driver=self.driver, | driver=self.driver, | ||||
device=device, | device=device, | ||||
batch_step_fn=validate_batch_step_fn, | |||||
mode=validate_mode, | |||||
batch_step_fn=evaluate_batch_step_fn, | |||||
evaluate_fn=evaluate_fn, | |||||
input_mapping=input_mapping, | input_mapping=input_mapping, | ||||
output_mapping=output_mapping, | output_mapping=output_mapping, | ||||
fp16=fp16, | fp16=fp16, | ||||
@@ -241,12 +246,16 @@ class Trainer(TrainerEventTrigger): | |||||
) | ) | ||||
self.metrics = metrics | self.metrics = metrics | ||||
self.validate_every = validate_every | |||||
self.validate_every = evaluate_every | |||||
assert self.driver.has_train_dataloader() | |||||
self.driver.setup() | self.driver.setup() | ||||
self.driver.barrier() | self.driver.barrier() | ||||
if train_fn is not None and not isinstance(train_fn, str): | |||||
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") | |||||
self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn) | |||||
self.train_fn = train_fn | |||||
self.dataloader = self.train_dataloader | self.dataloader = self.train_dataloader | ||||
self.driver.set_deterministic_dataloader(self.dataloader) | self.driver.set_deterministic_dataloader(self.dataloader) | ||||
@@ -273,6 +282,7 @@ class Trainer(TrainerEventTrigger): | |||||
行。默认如果非 distributed 的 driver 会 catch ,distributed 不会 catch (无法 catch ) | 行。默认如果非 distributed 的 driver 会 catch ,distributed 不会 catch (无法 catch ) | ||||
:return: | :return: | ||||
""" | """ | ||||
if catch_KeyboardInterrupt is None: | if catch_KeyboardInterrupt is None: | ||||
catch_KeyboardInterrupt = not self.driver.is_distributed() | catch_KeyboardInterrupt = not self.driver.is_distributed() | ||||
else: | else: | ||||
@@ -301,7 +311,7 @@ class Trainer(TrainerEventTrigger): | |||||
self.num_batches_per_epoch = len(self.dataloader) | self.num_batches_per_epoch = len(self.dataloader) | ||||
self.total_batches = self.num_batches_per_epoch * self.n_epochs | self.total_batches = self.num_batches_per_epoch * self.n_epochs | ||||
self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch | |||||
self.on_train_begin() | self.on_train_begin() | ||||
self.driver.barrier() | self.driver.barrier() | ||||
self.driver.zero_grad(self.set_grad_to_none) | self.driver.zero_grad(self.set_grad_to_none) | ||||
@@ -343,7 +353,8 @@ class Trainer(TrainerEventTrigger): | |||||
_validate_res: dict = validate_fn() | _validate_res: dict = validate_fn() | ||||
trainer.on_validate_end(_validate_res) | trainer.on_validate_end(_validate_res) | ||||
self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) | |||||
if self.evaluator is not None: | |||||
self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) | |||||
def step_validate(self): | def step_validate(self): | ||||
""" | """ | ||||
@@ -489,11 +500,6 @@ class Trainer(TrainerEventTrigger): | |||||
self.has_checked_train_batch_loop = True | self.has_checked_train_batch_loop = True | ||||
""" Trainer 需要的一些 property """ | """ Trainer 需要的一些 property """ | ||||
@property | |||||
def train_dataloader(self): | |||||
return self.driver.train_dataloader | |||||
@property | @property | ||||
def driver(self): | def driver(self): | ||||
return self._driver | return self._driver | ||||
@@ -632,6 +638,8 @@ class Trainer(TrainerEventTrigger): | |||||
:param folder: 保存断点重训 states 的文件地址; | :param folder: 保存断点重训 states 的文件地址; | ||||
:param resume_training: 是否从上次的 batch 开始训练,或者只从最近的 epoch 开始训练;注意如果 resume_training=True,那么我们 | :param resume_training: 是否从上次的 batch 开始训练,或者只从最近的 epoch 开始训练;注意如果 resume_training=True,那么我们 | ||||
只会加载 model 和 optimizers 的状态;而其余的对象的值则根据用户的 Trainer 的初始化直接重置; | 只会加载 model 和 optimizers 的状态;而其余的对象的值则根据用户的 Trainer 的初始化直接重置; | ||||
:param only_state_dict: 保存的 model 是否只包含了权重。 | |||||
:param model_load_fn: 使用的模型加载函数,参数应为一个 文件夹,不返回任何内容。 | |||||
""" | """ | ||||
self.driver.barrier() | self.driver.barrier() | ||||
if isinstance(folder, str): | if isinstance(folder, str): | ||||
@@ -670,8 +678,6 @@ class Trainer(TrainerEventTrigger): | |||||
# 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 | # 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 | ||||
# '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 | # '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 | ||||
self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') | self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') | ||||
self.trainer_state.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + \ | |||||
self.batch_idx_in_epoch | |||||
# 这个是防止用户在 Trainer.load 之后还没结束当前 epoch 又继续 save | # 这个是防止用户在 Trainer.load 之后还没结束当前 epoch 又继续 save | ||||
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | ||||
@@ -684,7 +690,7 @@ class Trainer(TrainerEventTrigger): | |||||
def train_step(self, batch): | def train_step(self, batch): | ||||
with self.driver.auto_cast(): | with self.driver.auto_cast(): | ||||
outputs = self.driver.train_step(batch) | |||||
outputs = self.driver.model_call(batch, self._train_step, self._train_step_signature_fn) | |||||
outputs = match_and_substitute_params(self.output_mapping, outputs) | outputs = match_and_substitute_params(self.output_mapping, outputs) | ||||
return outputs | return outputs | ||||
@@ -814,6 +820,24 @@ class Trainer(TrainerEventTrigger): | |||||
def data_device(self): | def data_device(self): | ||||
return self.driver.data_device | return self.driver.data_device | ||||
""" dataloader property """ | |||||
@property | |||||
def train_dataloader(self): | |||||
return self._train_dataloader | |||||
@train_dataloader.setter | |||||
def train_dataloader(self, train_dataloader): | |||||
self._train_dataloader = train_dataloader | |||||
@property | |||||
def evaluate_dataloaders(self): | |||||
return self._evaluate_dataloaders | |||||
@evaluate_dataloaders.setter | |||||
def evaluate_dataloaders(self, evaluate_dataloaders): | |||||
self._evaluate_dataloaders = evaluate_dataloaders | |||||
@@ -65,10 +65,10 @@ class TrainerState: | |||||
""" | """ | ||||
n_epochs: Optional[int] = None # 无论如何重新算 | n_epochs: Optional[int] = None # 无论如何重新算 | ||||
cur_epoch_idx: Optional[int] = None # 断点重训; 仅当 resume=False 时为0; | |||||
global_forward_batches: Optional[int] = None # 断点重训 | |||||
cur_epoch_idx: Optional[int] = 0 # 断点重训; 仅当 resume=False 时为0; | |||||
global_forward_batches: Optional[int] = 0 # 断点重训 | |||||
batch_idx_in_epoch: Optional[int] = None # 断点重训 | |||||
batch_idx_in_epoch: Optional[int] = 0 # 断点重训 | |||||
num_batches_per_epoch: Optional[int] = None # 无论如何重新算 | num_batches_per_epoch: Optional[int] = None # 无论如何重新算 | ||||
@@ -128,6 +128,6 @@ class _TruncatedDataLoader: | |||||
def check_validate_every(validate_every): | def check_validate_every(validate_every): | ||||
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | ||||
raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") | |||||
raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") | |||||
if callable(validate_every): | if callable(validate_every): | ||||
_check_valid_parameters_number(validate_every, expected_params=['trainer']) | _check_valid_parameters_number(validate_every, expected_params=['trainer']) |
@@ -1,7 +1,7 @@ | |||||
import os | import os | ||||
import signal | import signal | ||||
import sys | import sys | ||||
from typing import Any, Sequence, List, Optional, Callable, Dict, Union | |||||
from typing import Any, Sequence, List, Optional, Callable, Dict, Union, Tuple | |||||
from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||
from datetime import datetime | from datetime import datetime | ||||
from pathlib import Path | from pathlib import Path | ||||
@@ -79,41 +79,44 @@ class Driver(ABC): | |||||
""" | """ | ||||
@abstractmethod | @abstractmethod | ||||
def train_step(self, batch): | |||||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | |||||
""" | """ | ||||
通过调用模型自带的 `train_step` 或者 `forward` 方法来实现训练的前向过程; | |||||
如果检测到用户模型实现了 train_step | |||||
通过调用 `fn` 来实现训练时的前向传播过程; | |||||
注意 Trainer 和 Evaluator 会调用该函数来实现网络的前向传播过程,其中传入该函数的参数 `fn` 是函数 `get_model_call_fn` 所返回的 | |||||
函数; | |||||
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | :param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | ||||
:return: 返回由模型的 `train_step` 或者 `forward` 方法返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | |||||
:param fn: 调用该函数进行一次计算。 | |||||
:param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call | |||||
函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; | |||||
:return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | |||||
""" | """ | ||||
raise NotImplementedError("Each specific driver should implemented its own `train_step` function.") | |||||
raise NotImplementedError("Each specific driver should implemented its own `model_call` function.") | |||||
def validate_step(self, batch): | |||||
""" | |||||
通过调用模型自带的 `validate_step` 或者 `forward` 方法来实现模型评测的前向过程; | |||||
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | |||||
:return: 返回由模型的 `validate_step` 或者 `forward` 方法返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | |||||
@abstractmethod | |||||
def get_model_call_fn(self, fn: str) -> Tuple: | |||||
""" | """ | ||||
raise NotImplementedError("Each specific driver should implemented its own `validate_step` function.") | |||||
该函数会接受 Trainer 的 train_fn 或者 Evaluator 的 evaluate_fn,返回一个实际用于调用 driver.model_call 时传入的函数参数; | |||||
该函数会在 Trainer 和 Evaluator 在 driver.setup 函数之后调用; | |||||
def test_step(self, batch): | |||||
""" | |||||
通过调用模型自带的 `test_step` 或者 `forward` 方法来实现模型评测的前向过程; | |||||
之所以设置该函数的目的在于希望将具体的 model_call function 从 driver 中抽离出来,然后将其附着在 Trainer 或者 Evaluator 身上; | |||||
这样是因为在新版的设计中,使用 model 的哪种方法来进行 `train step` 或者 `evaluate step` 是通过额外的参数 `train_fn` 和 | |||||
`evaluate_fn` 来确定的,而二者又分别是通过 Trainer 和 Evaluator 来控制的;因此不能将确定具体的 `train step fn` 和 | |||||
`evaluate step fn` 的逻辑放在每一个 driver 的初始化的时候(因此在 Trainer 初始化第一个 driver 时,Evaluator 还没有初始化,但是 | |||||
`evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中; | |||||
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | |||||
:return: 返回由模型的 `test_step` 或者 `forward` 方法返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `test_step` function.") | |||||
这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数,具体逻辑如下所示: | |||||
1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward` | |||||
函数,然后给出 warning; | |||||
2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错; | |||||
注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的 | |||||
forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此 | |||||
可能需要额外标记最初传入 driver 的模型是哪种形式的; | |||||
def check_evaluator_mode(self, mode: str): | |||||
r""" | |||||
因为我们在具体的 driver 的 validate_step 和 test_step 的逻辑是如果模型没有实现本函数,那么就去检测模型是否实现了另一个函数; | |||||
因此如果用户的 evaluator mode 是 validate,但是传入的 model 却没有实现 validate_step 函数,而是实现了 test_step 函数,那么 | |||||
我们应当提醒用户这一行为; | |||||
:param fn: 应当为一个字符串,该函数通过该字符串判断要返回模型的哪种方法; | |||||
:return: 返回一个元组,包含两个函数,用于在调用 driver.model_call 时传入; | |||||
""" | """ | ||||
raise NotImplementedError("Each specific driver should implemented its own `check_evaluator_mode` function.") | |||||
raise NotImplementedError("Each specific driver should implemented its own `get_model_call_fn` function.") | |||||
@property | @property | ||||
def model(self): | def model(self): | ||||
@@ -123,80 +126,6 @@ class Driver(ABC): | |||||
def model(self, model): | def model(self, model): | ||||
self._model = model | self._model = model | ||||
@property | |||||
def train_dataloader(self): | |||||
return self._train_dataloader | |||||
@train_dataloader.setter | |||||
def train_dataloader(self, train_dataloader: Any): | |||||
self._train_dataloader = train_dataloader | |||||
@property | |||||
def validate_dataloaders(self): | |||||
return self._validate_dataloaders | |||||
@validate_dataloaders.setter | |||||
def validate_dataloaders(self, validate_dataloaders: Any): | |||||
self._validate_dataloaders = validate_dataloaders | |||||
@property | |||||
def test_dataloaders(self): | |||||
return self._test_dataloaders | |||||
@test_dataloaders.setter | |||||
def test_dataloaders(self, test_dataloaders: Any): | |||||
self._test_dataloaders = test_dataloaders | |||||
@property | |||||
def predict_dataloaders(self): | |||||
return self._predict_dataloaders | |||||
@predict_dataloaders.setter | |||||
def predict_dataloaders(self, predict_dataloaders: Any): | |||||
self._predict_dataloaders = predict_dataloaders | |||||
def set_dataloader(self, **kwargs): | |||||
r""" | |||||
设置训练或者检验过程中的数据;用于在 trainer 和 evaluator 中将数据 dataloader 挂载到每一个具体的 driver 上; | |||||
:param kwargs: 输入的数据,应当使用 'keyword-only' 的参数进行设置; | |||||
""" | |||||
if "train_dataloader" in kwargs: | |||||
self.train_dataloader = kwargs["train_dataloader"] | |||||
self._check_dataloader_legality(self.train_dataloader, "train_dataloader", is_train=True) | |||||
if "validate_dataloaders" in kwargs: | |||||
self.validate_dataloaders = kwargs["validate_dataloaders"] | |||||
self._check_dataloader_legality(self.validate_dataloaders, "validate_dataloaders", is_train=False) | |||||
if "test_dataloaders" in kwargs: | |||||
self.test_dataloaders = kwargs["test_dataloaders"] | |||||
self._check_dataloader_legality(self.test_dataloaders, "test_dataloaders", is_train=False) | |||||
if "predict_dataloaders" in kwargs: | |||||
self.predict_dataloaders = kwargs["predict_dataloaders"] | |||||
self._check_dataloader_legality(self.predict_dataloaders, "predict_dataloaders", is_train=False) | |||||
@staticmethod | |||||
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||||
r""" | |||||
该函数会在 trainer 或者 evaluator 设置 dataloader 后检测 dataloader 的合法性,因为不同的深度学习的框架需要的 dataloader 的 | |||||
行为是不相同的; | |||||
:param dataloader: 需要检测的输入的 `dataloader`; | |||||
:param dataloader_name: | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `_check_dataloader_legality` function.") | |||||
def has_train_dataloader(self): | |||||
return "_train_dataloader" in self.__dict__ | |||||
def has_validate_dataloaders(self): | |||||
return "_validate_dataloaders" in self.__dict__ | |||||
def has_test_dataloaders(self): | |||||
return "_test_dataloaders" in self.__dict__ | |||||
def has_predict_dataloaders(self): | |||||
return "_predict_dataloaders" in self.__dict__ | |||||
@property | @property | ||||
def optimizers(self) -> List: | def optimizers(self) -> List: | ||||
r""" | r""" | ||||
@@ -39,7 +39,7 @@ class JittorDriver(Driver): | |||||
self.grad_scaler = _grad_scaler() | self.grad_scaler = _grad_scaler() | ||||
@staticmethod | @staticmethod | ||||
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||||
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||||
# 在fastnlp中实现了JittorDataLoader | # 在fastnlp中实现了JittorDataLoader | ||||
# TODO: 是否允许传入Dataset? | # TODO: 是否允许传入Dataset? | ||||
if is_train: | if is_train: | ||||
@@ -64,18 +64,18 @@ class JittorDriver(Driver): | |||||
def check_evaluator_mode(self, mode: str): | def check_evaluator_mode(self, mode: str): | ||||
model = self.unwrap_model() | model = self.unwrap_model() | ||||
if mode == "validate": | if mode == "validate": | ||||
if not hasattr(model, "validate_step"): | |||||
if not hasattr(model, "evaluate_step"): | |||||
if hasattr(model, "test_step"): | if hasattr(model, "test_step"): | ||||
logger.warning_once( | logger.warning_once( | ||||
"Your model does not have 'validate_step' method but has 'test_step' method, but you" | |||||
"are using 'mode=validate', we are going to use 'test_step' to substitute for" | |||||
"'validate_step'.") | |||||
"Your model does not have 'evaluate_step' method but has 'test_step' method, but you" | |||||
"are using 'evaluate_fn=validate', we are going to use 'test_step' to substitute for" | |||||
"'evaluate_step'.") | |||||
else: | else: | ||||
if not hasattr(model, "test_step"): | if not hasattr(model, "test_step"): | ||||
if hasattr(model, "validate_step"): | |||||
if hasattr(model, "evaluate_step"): | |||||
logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" | logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" | ||||
"are using 'mode=test', we are going to use 'validate_step' to substitute for" | |||||
"are using 'evaluate_fn=test', we are going to use 'evaluate_step' to substitute for" | |||||
"'test_step'.") | "'test_step'.") | ||||
def save_model(self, filepath: str, only_state_dict: bool = False, model_save_fn: Optional[Callable]=None): | def save_model(self, filepath: str, only_state_dict: bool = False, model_save_fn: Optional[Callable]=None): | ||||
@@ -35,8 +35,8 @@ class JittorSingleDriver(JittorDriver): | |||||
model = self.unwrap_model() | model = self.unwrap_model() | ||||
self._train_signature_fn = model.execute | self._train_signature_fn = model.execute | ||||
if hasattr(self.model, "validate_step"): | |||||
self._validate_step = self.model.validate_step | |||||
if hasattr(self.model, "evaluate_step"): | |||||
self._validate_step = self.model.evaluate_step | |||||
self._validate_signature_fn = None | self._validate_signature_fn = None | ||||
elif hasattr(self.model, "test_step"): | elif hasattr(self.model, "test_step"): | ||||
self._validate_step = self.model.test_step | self._validate_step = self.model.test_step | ||||
@@ -49,9 +49,9 @@ class JittorSingleDriver(JittorDriver): | |||||
if hasattr(self.model, "test_step"): | if hasattr(self.model, "test_step"): | ||||
self._test_step = self.model.test_step | self._test_step = self.model.test_step | ||||
self._test_signature_fn = None | self._test_signature_fn = None | ||||
elif hasattr(self.model, "validate_step"): | |||||
self._test_step = self.model.validate_step | |||||
self._test_signature_fn = self.model.validate_step | |||||
elif hasattr(self.model, "evaluate_step"): | |||||
self._test_step = self.model.evaluate_step | |||||
self._test_signature_fn = self.model.evaluate_step | |||||
else: | else: | ||||
self._test_step = self.model | self._test_step = self.model | ||||
model = self.unwrap_model() | model = self.unwrap_model() | ||||
@@ -118,11 +118,11 @@ class PaddleFleetDriver(PaddleDriver): | |||||
" call `forward` function instead of `train_step` and you should note that.") | " call `forward` function instead of `train_step` and you should note that.") | ||||
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | ||||
if hasattr(model, "validate_step"): | |||||
if hasattr(model, "evaluate_step"): | |||||
logger.warning( | logger.warning( | ||||
"Notice your model is a `paddle.DataParallel` model. And your " | "Notice your model is a `paddle.DataParallel` model. And your " | ||||
"model also implements the `validate_step` method, which we can not call actually, " | |||||
"we will call `forward` function instead of `validate_step` and you should note that.") | |||||
"model also implements the `evaluate_step` method, which we can not call actually, " | |||||
"we will call `forward` function instead of `evaluate_step` and you should note that.") | |||||
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | ||||
if hasattr(model, "test_step"): | if hasattr(model, "test_step"): | ||||
@@ -34,10 +34,10 @@ if _NEED_IMPORT_PADDLE: | |||||
from paddle.optimizer import Optimizer | from paddle.optimizer import Optimizer | ||||
_reduces = { | _reduces = { | ||||
'max': paddle.max, | |||||
'min': paddle.min, | |||||
'mean': paddle.mean, | |||||
'sum': paddle.sum | |||||
"max": paddle.max, | |||||
"min": paddle.min, | |||||
"mean": paddle.mean, | |||||
"sum": paddle.sum | |||||
} | } | ||||
class PaddleDriver(Driver): | class PaddleDriver(Driver): | ||||
@@ -72,7 +72,7 @@ class PaddleDriver(Driver): | |||||
optimizer.clear_grad() | optimizer.clear_grad() | ||||
@staticmethod | @staticmethod | ||||
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||||
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||||
r""" | r""" | ||||
该函数会在 trainer 或者 evaluator 设置 dataloader 后检测 dataloader 的合法性。 | 该函数会在 trainer 或者 evaluator 设置 dataloader 后检测 dataloader 的合法性。 | ||||
要求传入的 dataloader 必须为 `paddle.io.DataLoader` 或包含该类型的字典。 | 要求传入的 dataloader 必须为 `paddle.io.DataLoader` 或包含该类型的字典。 | ||||
@@ -117,24 +117,24 @@ class PaddleDriver(Driver): | |||||
def check_evaluator_mode(self, mode: str): | def check_evaluator_mode(self, mode: str): | ||||
r""" | r""" | ||||
因为我们在具体的 driver 的 validate_step 和 test_step 的逻辑是如果模型没有实现本函数,那么就去检测模型是否实现了另一个函数; | |||||
因此如果用户的 evaluator mode 是 validate,但是传入的 model 却没有实现 validate_step 函数,而是实现了 test_step 函数,那么 | |||||
因为我们在具体的 driver 的 evaluate_step 和 test_step 的逻辑是如果模型没有实现本函数,那么就去检测模型是否实现了另一个函数; | |||||
因此如果用户的 evaluator evaluate_fn 是 validate,但是传入的 model 却没有实现 evaluate_step 函数,而是实现了 test_step 函数,那么 | |||||
我们应当提醒用户这一行为; | 我们应当提醒用户这一行为; | ||||
""" | """ | ||||
model = self.unwrap_model() | model = self.unwrap_model() | ||||
if mode == "validate": | if mode == "validate": | ||||
if not hasattr(model, "validate_step"): | |||||
if not hasattr(model, "evaluate_step"): | |||||
if hasattr(model, "test_step"): | if hasattr(model, "test_step"): | ||||
logger.warning( | logger.warning( | ||||
"Your model does not have 'validate_step' method but has 'test_step' method, but you" | |||||
"Your model does not have 'evaluate_step' method but has 'test_step' method, but you" | |||||
"are using 'Evaluator.validate', we are going to use 'test_step' to substitute for" | "are using 'Evaluator.validate', we are going to use 'test_step' to substitute for" | ||||
"'validate_step'.") | |||||
"'evaluate_step'.") | |||||
else: | else: | ||||
if not hasattr(model, "test_step"): | if not hasattr(model, "test_step"): | ||||
if hasattr(model, "validate_step"): | |||||
if hasattr(model, "evaluate_step"): | |||||
logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" | logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" | ||||
"are using 'Evaluator.test', we are going to use 'validate_step' to substitute for" | |||||
"are using 'Evaluator.test', we are going to use 'evaluate_step' to substitute for" | |||||
"'test_step'.") | "'test_step'.") | ||||
@staticmethod | @staticmethod | ||||
@@ -254,24 +254,24 @@ class PaddleDriver(Driver): | |||||
else: | else: | ||||
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | ||||
num_consumed_batches = states.pop('num_consumed_batches') | |||||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | |||||
num_consumed_batches = states.pop("num_consumed_batches") | |||||
if hasattr(sampler, "state_dict") and callable(sampler.state_dict): | |||||
sampler_states = sampler.state_dict() | sampler_states = sampler.state_dict() | ||||
# 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | # 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | ||||
# 会造成多余实际消耗的问题。 | |||||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||||
# 会造成多余实际消耗的问题。 | |||||
num_consumed_samples_array = sampler_states.pop("num_consumed_samples_array", None) | |||||
if num_consumed_samples_array is not None: | if num_consumed_samples_array is not None: | ||||
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||||
try: | |||||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||||
except: # 有可能 batch_size 为 None,就只有损失精度了 | |||||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||||
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] | |||||
else: | |||||
try: | |||||
sampler_states["num_consumed_samples"] = num_consumed_batches * dataloader_args.batch_size | |||||
except: # 有可能 batch_size 为 None,就只有损失精度了 | |||||
pass | |||||
assert sampler_states["num_consumed_samples"] != -1, "This is a bug, please report." | |||||
else: | else: | ||||
raise RuntimeError( | raise RuntimeError( | ||||
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | |||||
"The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") | |||||
states["sampler_states"] = sampler_states | |||||
# 2. 保存模型的状态; | # 2. 保存模型的状态; | ||||
if should_save_model: | if should_save_model: | ||||
@@ -326,7 +326,7 @@ class PaddleDriver(Driver): | |||||
batch_size=dataloader_args.batch_size, | batch_size=dataloader_args.batch_size, | ||||
drop_last=dataloader_args.drop_last | drop_last=dataloader_args.drop_last | ||||
) | ) | ||||
sampler.load_state_dict(states['sampler_states']) | |||||
sampler.load_state_dict(states["sampler_states"]) | |||||
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | ||||
# 4. 修改 trainer_state.batch_idx_in_epoch | # 4. 修改 trainer_state.batch_idx_in_epoch | ||||
@@ -355,7 +355,7 @@ class PaddleDriver(Driver): | |||||
return paddle.no_grad | return paddle.no_grad | ||||
@staticmethod | @staticmethod | ||||
def move_model_to_device(model: 'paddle.nn.Layer', device: Union[str, int, 'paddle.CUDAPlace', 'paddle.CPUPlace']): | |||||
def move_model_to_device(model: "paddle.nn.Layer", device: Union[str, int, "paddle.CUDAPlace", "paddle.CPUPlace"]): | |||||
r""" | r""" | ||||
用来将模型转移到指定的 device 上; | 用来将模型转移到指定的 device 上; | ||||
在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 | 在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 | ||||
@@ -363,7 +363,7 @@ class PaddleDriver(Driver): | |||||
if device is not None: | if device is not None: | ||||
model.to(device) | model.to(device) | ||||
def move_data_to_device(self, batch: 'paddle.Tensor'): | |||||
def move_data_to_device(self, batch: "paddle.Tensor"): | |||||
r""" | r""" | ||||
将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。 | 将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。 | ||||
在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 | 在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 | ||||
@@ -404,7 +404,7 @@ class PaddleDriver(Driver): | |||||
if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None: | if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None: | ||||
dataloader.worker_init_fn = partial(self.worker_init_function, rank=self.global_rank) | dataloader.worker_init_fn = partial(self.worker_init_function, rank=self.global_rank) | ||||
def set_sampler_epoch(self, dataloader: 'DataLoader', cur_epoch_idx): | |||||
def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx): | |||||
r""" | r""" | ||||
对于分布式的 sampler,dataloader 需要在每一个 epoch 前设置随机数种子,来保证每一个进程上的 shuffle 是一样的; | 对于分布式的 sampler,dataloader 需要在每一个 epoch 前设置随机数种子,来保证每一个进程上的 shuffle 是一样的; | ||||
@@ -50,10 +50,10 @@ class PaddleSingleDriver(PaddleDriver): | |||||
self._train_step = self.model | self._train_step = self.model | ||||
self._train_signature_fn = model.forward | self._train_signature_fn = model.forward | ||||
if hasattr(model, "validate_step"): | |||||
if hasattr(model, "evaluate_step"): | |||||
logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also " | logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also " | ||||
"implements the `validate_step` method, which we can not call actually, we " | |||||
"will call `forward` function instead of `validate_step` and you should note that.") | |||||
"implements the `evaluate_step` method, which we can not call actually, we " | |||||
"will call `forward` function instead of `evaluate_step` and you should note that.") | |||||
self._validate_step = self.model | self._validate_step = self.model | ||||
self._validate_signature_fn = model.forward | self._validate_signature_fn = model.forward | ||||
@@ -73,8 +73,8 @@ class PaddleSingleDriver(PaddleDriver): | |||||
model = self.unwrap_model() | model = self.unwrap_model() | ||||
self._train_signature_fn = model.forward | self._train_signature_fn = model.forward | ||||
if hasattr(self.model, "validate_step"): | |||||
self._validate_step = self.model.validate_step | |||||
if hasattr(self.model, "evaluate_step"): | |||||
self._validate_step = self.model.evaluate_step | |||||
self._validate_signature_fn = None | self._validate_signature_fn = None | ||||
elif hasattr(self.model, "test_step"): | elif hasattr(self.model, "test_step"): | ||||
self._validate_step = self.model.test_step | self._validate_step = self.model.test_step | ||||
@@ -87,9 +87,9 @@ class PaddleSingleDriver(PaddleDriver): | |||||
if hasattr(self.model, "test_step"): | if hasattr(self.model, "test_step"): | ||||
self._test_step = self.model.test_step | self._test_step = self.model.test_step | ||||
self._test_signature_fn = None | self._test_signature_fn = None | ||||
elif hasattr(self.model, "validate_step"): | |||||
self._test_step = self.model.validate_step | |||||
self._test_signature_fn = self.model.validate_step | |||||
elif hasattr(self.model, "evaluate_step"): | |||||
self._test_step = self.model.evaluate_step | |||||
self._test_signature_fn = self.model.evaluate_step | |||||
else: | else: | ||||
self._test_step = self.model | self._test_step = self.model | ||||
model = self.unwrap_model() | model = self.unwrap_model() | ||||
@@ -108,11 +108,11 @@ class _FleetWrappingModel(Layer): | |||||
self._train_step = self.model | self._train_step = self.model | ||||
self._train_signature_fn = model.forward | self._train_signature_fn = model.forward | ||||
if hasattr(model, "validate_step"): | |||||
if hasattr(model, "evaluate_step"): | |||||
logger.warning( | logger.warning( | ||||
"Notice your model is a `paddle.DataParallel` model. And your " | "Notice your model is a `paddle.DataParallel` model. And your " | ||||
"model also implements the `validate_step` method, which we can not call actually, " | |||||
"we will call `forward` function instead of `validate_step` and you should note that.") | |||||
"model also implements the `evaluate_step` method, which we can not call actually, " | |||||
"we will call `forward` function instead of `evaluate_step` and you should note that.") | |||||
self._validate_step = self.model | self._validate_step = self.model | ||||
self._validate_signature_fn = model.forward | self._validate_signature_fn = model.forward | ||||
@@ -131,7 +131,7 @@ class _FleetWrappingModel(Layer): | |||||
self._train_step = model | self._train_step = model | ||||
self._train_signature_fn = model.forward | self._train_signature_fn = model.forward | ||||
if hasattr(model, "validate_step"): | |||||
if hasattr(model, "evaluate_step"): | |||||
self._validate_step = model.validate_step | self._validate_step = model.validate_step | ||||
self._validate_signature_fn = None | self._validate_signature_fn = None | ||||
elif hasattr(model, "test_step"): | elif hasattr(model, "test_step"): | ||||
@@ -144,7 +144,7 @@ class _FleetWrappingModel(Layer): | |||||
if hasattr(model, "test_step"): | if hasattr(model, "test_step"): | ||||
self._test_step = model.test_step | self._test_step = model.test_step | ||||
self._test_signature_fn = None | self._test_signature_fn = None | ||||
elif hasattr(model, "validate_step"): | |||||
elif hasattr(model, "evaluate_step"): | |||||
self._test_step = model.validate_step | self._test_step = model.validate_step | ||||
self._test_signature_fn = None | self._test_signature_fn = None | ||||
else: | else: | ||||
@@ -172,9 +172,9 @@ class _FleetWrappingModel(Layer): | |||||
else: | else: | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
elif forward_state == ForwardState.PREDICT: | elif forward_state == ForwardState.PREDICT: | ||||
raise NotImplementedError("'PREDICT' mode has not been implemented.") | |||||
raise NotImplementedError("'PREDICT' evaluate_fn has not been implemented.") | |||||
else: | else: | ||||
raise NotImplementedError("You should direct a concrete mode.") | |||||
raise NotImplementedError("You should direct a concrete evaluate_fn.") | |||||
class DummyGradScaler: | class DummyGradScaler: | ||||
""" | """ | ||||
@@ -4,7 +4,7 @@ import __main__ | |||||
import socket | import socket | ||||
import numpy as np | import numpy as np | ||||
from time import sleep | from time import sleep | ||||
from typing import List, Optional, Union, Dict | |||||
from typing import List, Optional, Union, Dict, Tuple, Callable | |||||
from functools import partial | from functools import partial | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
@@ -21,8 +21,6 @@ __all__ = [ | |||||
from .torch_driver import TorchDriver | from .torch_driver import TorchDriver | ||||
from fastNLP.core.drivers.torch_driver.utils import ( | from fastNLP.core.drivers.torch_driver.utils import ( | ||||
_DDPWrappingModel, | _DDPWrappingModel, | ||||
ForwardState, | |||||
_MODE_PARAMETER, | |||||
reset_seed, | reset_seed, | ||||
replace_sampler, | replace_sampler, | ||||
replace_batch_sampler | replace_batch_sampler | ||||
@@ -158,10 +156,10 @@ class TorchDDPDriver(TorchDriver): | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | ———————————————————————————————————————————————————————————————————————————————————————————————————————— | ||||
3. _DDPWrappingModel 的作用; | 3. _DDPWrappingModel 的作用; | ||||
因为我们即需要调用模型的 `train_step`、`validate_step`、`test_step` 方法,又需要通过 `DistributedDataParallel` 的 | |||||
因为我们即需要调用模型的 `train_step`、`evaluate_step`、`test_step` 方法,又需要通过 `DistributedDataParallel` 的 | |||||
forward 函数来帮助我们同步各个设备上的梯度,因此我们需要先将模型单独包裹一层,然后在 forward 的时候,其先经过 `DistributedDataParallel` | forward 函数来帮助我们同步各个设备上的梯度,因此我们需要先将模型单独包裹一层,然后在 forward 的时候,其先经过 `DistributedDataParallel` | ||||
的 forward 方法,然后再经过 `_DDPWrappingModel` 的 forward 方法,我们会在该 forward 函数中进行判断,确定调用的是模型自己的 | 的 forward 方法,然后再经过 `_DDPWrappingModel` 的 forward 方法,我们会在该 forward 函数中进行判断,确定调用的是模型自己的 | ||||
forward 函数,还是 `train_step`、`validate_step`、`test_step` 方法。 | |||||
forward 函数,还是 `train_step`、`evaluate_step`、`test_step` 方法。 | |||||
4. 当某一个进程出现 exception 后,`TorchDDPDriver` 的处理; | 4. 当某一个进程出现 exception 后,`TorchDDPDriver` 的处理; | ||||
@@ -204,37 +202,6 @@ class TorchDDPDriver(TorchDriver): | |||||
# 我们就直接将 model_device 置为 None; | # 我们就直接将 model_device 置为 None; | ||||
self.model_device = None | self.model_device = None | ||||
def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call): | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(step_fn, batch, signature_fn=signature_fn) | |||||
else: | |||||
return step_fn(batch) | |||||
model = model.module | |||||
if hasattr(model, "train_step"): | |||||
logger.warning( | |||||
"Notice your model is a `DistributedDataParallel` model. And your " | |||||
"model also implements the `train_step` method, which we can not call actually, we will" | |||||
" call `forward` function instead of `train_step` and you should note that.") | |||||
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
# self._train_signature_fn = model.forward | |||||
if hasattr(model, "validate_step"): | |||||
logger.warning( | |||||
"Notice your model is a `DistributedDataParallel` model. And your " | |||||
"model also implements the `validate_step` method, which we can not call actually, " | |||||
"we will call `forward` function instead of `validate_step` and you should note that.") | |||||
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
# self._validate_signature_fn = model.forward | |||||
if hasattr(model, "test_step"): | |||||
logger.warning( | |||||
"Notice your model is a `DistributedDataParallel` model. And your " | |||||
"model also implements the `test_step` method, which we can not call actually, we will" | |||||
" call `forward` function instead of `test_step` and you should note that.") | |||||
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
# self._test_signature_fn = model.forward | |||||
# 当用户自己在外面初始化 DDP 时我们会将 model_device 置为 None,这是用户可以通过 `data_device` 将对应的数据移到指定的机器上; | # 当用户自己在外面初始化 DDP 时我们会将 model_device 置为 None,这是用户可以通过 `data_device` 将对应的数据移到指定的机器上; | ||||
self._data_device = kwargs.get("data_device", None) | self._data_device = kwargs.get("data_device", None) | ||||
if isinstance(self._data_device, int): | if isinstance(self._data_device, int): | ||||
@@ -253,7 +220,6 @@ class TorchDDPDriver(TorchDriver): | |||||
# world_size 表示的就是全局的显卡的数量; | # world_size 表示的就是全局的显卡的数量; | ||||
self.world_size = None # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device) | self.world_size = None # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device) | ||||
self.global_rank = 0 | self.global_rank = 0 | ||||
self._configured = False # 防止重复调用 configure_ddp() 函数使用的 | |||||
self._ddp_kwargs = kwargs.get("torch_ddp_kwargs", {}) | self._ddp_kwargs = kwargs.get("torch_ddp_kwargs", {}) | ||||
check_user_specific_params(self._ddp_kwargs, DistributedDataParallel.__init__) | check_user_specific_params(self._ddp_kwargs, DistributedDataParallel.__init__) | ||||
@@ -268,8 +234,8 @@ class TorchDDPDriver(TorchDriver): | |||||
os.makedirs(name=self.output_from_new_proc, exist_ok=True) | os.makedirs(name=self.output_from_new_proc, exist_ok=True) | ||||
self.output_from_new_proc = os.path.abspath(self.output_from_new_proc) | self.output_from_new_proc = os.path.abspath(self.output_from_new_proc) | ||||
# 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; | |||||
self._has_setup = False | |||||
self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; | |||||
self._has_ddpwrapped = False # 判断传入的模型是否经过 _has_ddpwrapped 包裹; | |||||
def setup(self): | def setup(self): | ||||
if self._has_setup: | if self._has_setup: | ||||
@@ -341,24 +307,16 @@ class TorchDDPDriver(TorchDriver): | |||||
self._pids = self.tensor_to_numeric(self._pids) | self._pids = self.tensor_to_numeric(self._pids) | ||||
def configure_ddp(self): | def configure_ddp(self): | ||||
if not self._configured and not isinstance(self.model, DistributedDataParallel): | |||||
if not isinstance(self.model, DistributedDataParallel): | |||||
self.model = DistributedDataParallel( | self.model = DistributedDataParallel( | ||||
# 注意这里的 self.model_device 是 `torch.device` type,因此 self.model_device.index; | # 注意这里的 self.model_device 是 `torch.device` type,因此 self.model_device.index; | ||||
_DDPWrappingModel(self.model), device_ids=[self.model_device.index], | _DDPWrappingModel(self.model), device_ids=[self.model_device.index], | ||||
**self._ddp_kwargs | **self._ddp_kwargs | ||||
) | ) | ||||
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call) | |||||
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call) | |||||
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call) | |||||
self._configured = True | |||||
self._has_ddpwrapped = True | |||||
def open_subprocess(self): | def open_subprocess(self): | ||||
if self.local_rank == 0: | if self.local_rank == 0: | ||||
# self._consensus_file = Path(tempfile.mkstemp()[1]) | |||||
# self._consensus_file.unlink() | |||||
# Script called as `python a/b/c.py` | # Script called as `python a/b/c.py` | ||||
if __main__.__spec__ is None: # pragma: no-cover | if __main__.__spec__ is None: # pragma: no-cover | ||||
# pull out the commands used to run the script and resolve the abs file path | # pull out the commands used to run the script and resolve the abs file path | ||||
@@ -432,18 +390,39 @@ class TorchDDPDriver(TorchDriver): | |||||
return self._data_device | return self._data_device | ||||
return self.model_device | return self.model_device | ||||
def train_step(self, batch): | |||||
# 注意这里的 self.model 已经是 'fastNLP.drivers.utils._DDPWrappingModel'; | |||||
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.TRAIN}) | |||||
return self._train_step(batch) | |||||
def validate_step(self, batch): | |||||
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.VALIDATE}) | |||||
return self._validate_step(batch) | |||||
def test_step(self, batch): | |||||
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST}) | |||||
return self._test_step(batch) | |||||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | |||||
if self._has_ddpwrapped: | |||||
return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn, | |||||
wo_auto_param_call=self.wo_auto_param_call) | |||||
else: | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
return auto_param_call(fn, batch, signature_fn=signature_fn) | |||||
else: | |||||
return fn(batch) | |||||
def get_model_call_fn(self, fn: str) -> Tuple: | |||||
model = self.unwrap_model() | |||||
if self._has_ddpwrapped: | |||||
if hasattr(model, fn): | |||||
fn = getattr(model, fn) | |||||
if not callable(fn): | |||||
raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.") | |||||
return fn, None | |||||
elif fn in {"train_step", "evaluate_step"}: | |||||
return model, model.forward | |||||
else: | |||||
raise RuntimeError(f"There is no `{fn}` method in your model.") | |||||
else: | |||||
if hasattr(model, fn): | |||||
logger.warning("Notice your model is a `DistributedDataParallel` model. And your model also implements " | |||||
f"the `{fn}` method, which we can not call actually, we will" | |||||
" call `forward` function instead of `train_step` and you should note that.") | |||||
elif fn not in {"train_step", "evaluate_step"}: | |||||
raise RuntimeError(f"There is no `{fn}` method in your model. And also notice that your model is a " | |||||
"`DistributedDataParallel` model, which means that we will only call model.forward " | |||||
"function when we are in forward propagation.") | |||||
return self.model, model.forward | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]]=None, | def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]]=None, | ||||
reproducible: bool = False): | reproducible: bool = False): | ||||
@@ -1,10 +1,11 @@ | |||||
import os | import os | ||||
from typing import Dict, Union | |||||
from typing import Dict, Union, Callable, Tuple, Optional | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
from torch.nn import DataParallel | from torch.nn import DataParallel | ||||
from torch.nn.parallel import DistributedDataParallel | from torch.nn.parallel import DistributedDataParallel | ||||
from torch.utils.data import RandomSampler as TorchRandomSampler | |||||
__all__ = [ | __all__ = [ | ||||
'TorchSingleDriver' | 'TorchSingleDriver' | ||||
@@ -13,7 +14,9 @@ __all__ = [ | |||||
from .torch_driver import TorchDriver | from .torch_driver import TorchDriver | ||||
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | ||||
from fastNLP.core.utils import auto_param_call | from fastNLP.core.utils import auto_param_call | ||||
from fastNLP.core.utils.utils import _get_fun_msg | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler | from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler | ||||
from fastNLP.core.samplers import RandomSampler | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -42,84 +45,42 @@ class TorchSingleDriver(TorchDriver): | |||||
self.global_rank = 0 | self.global_rank = 0 | ||||
self.world_size = 1 | self.world_size = 1 | ||||
if isinstance(model, DataParallel): | |||||
model = self.unwrap_model() | |||||
if hasattr(model, "train_step"): | |||||
logger.warning("Notice your model is a `DataParallel` or `DistributedDataParallel` model. And your " | |||||
"model also implements the `train_step` method, which we can not call actually, we will" | |||||
" call `forward` function instead of `train_step` and you should note that.") | |||||
self._train_step = self.model | |||||
self._train_signature_fn = model.forward | |||||
if hasattr(model, "validate_step"): | |||||
logger.warning("Notice your model is a `DataParallel` or `DistributedDataParallel` model. And your " | |||||
"model also implements the `validate_step` method, which we can not call actually, " | |||||
"we will call `forward` function instead of `validate_step` and you should note that.") | |||||
self._validate_step = self.model | |||||
self._validate_signature_fn = model.forward | |||||
if hasattr(model, "test_step"): | |||||
logger.warning("Notice your model is a `DataParallel` or `DistributedDataParallel` model. And your " | |||||
"model also implements the `test_step` method, which we can not call actually, we will" | |||||
" call `forward` function instead of `test_step` and you should note that.") | |||||
self._test_step = self.model | |||||
self._test_signature_fn = model.forward | |||||
else: | |||||
if hasattr(self.model, "train_step"): | |||||
self._train_step = self.model.train_step | |||||
self._train_signature_fn = None | |||||
else: | |||||
self._train_step = self.model | |||||
# 输入的模型是 `DataParallel` 或者 `DistributedDataParallel`,我们需要保证其 signature_fn 是正确的; | |||||
model = self.unwrap_model() | |||||
self._train_signature_fn = model.forward | |||||
if hasattr(self.model, "validate_step"): | |||||
self._validate_step = self.model.validate_step | |||||
self._validate_signature_fn = None | |||||
elif hasattr(self.model, "test_step"): | |||||
self._validate_step = self.model.test_step | |||||
self._validate_signature_fn = self.model.test_step | |||||
else: | |||||
self._validate_step = self.model | |||||
model = self.unwrap_model() | |||||
self._validate_signature_fn = model.forward | |||||
if hasattr(self.model, "test_step"): | |||||
self._test_step = self.model.test_step | |||||
self._test_signature_fn = None | |||||
elif hasattr(self.model, "validate_step"): | |||||
self._test_step = self.model.validate_step | |||||
self._test_signature_fn = self.model.validate_step | |||||
else: | |||||
self._test_step = self.model | |||||
model = self.unwrap_model() | |||||
self._test_signature_fn = model.forward | |||||
def setup(self): | def setup(self): | ||||
if self.model_device is not None: | if self.model_device is not None: | ||||
self.model.to(self.model_device) | self.model.to(self.model_device) | ||||
def train_step(self, batch) -> Dict: | |||||
# 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; | |||||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | if isinstance(batch, Dict) and not self.wo_auto_param_call: | ||||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | |||||
return auto_param_call(fn, batch, signature_fn=signature_fn) | |||||
else: | else: | ||||
return self._train_step(batch) | |||||
return fn(batch) | |||||
def validate_step(self, batch) -> Dict: | |||||
# 因为我们 Tester 的逻辑就是将所有的 metric 传给 tester,然后 tester 控制具体 metric 的 update 和 compute;因此不管用户是否 | |||||
# 实现 validate_step 函数,其都应该返回一个字典,具体使用哪些东西则是在 validate_batch_loop 中每一个具体的 metric 自己去拿的; | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | |||||
else: | |||||
return self._validate_step(batch) | |||||
def get_model_call_fn(self, fn: str) -> Tuple: | |||||
if isinstance(self.model, DataParallel): | |||||
model = self.unwrap_model() | |||||
if hasattr(model, fn): | |||||
logger.warning("Notice your model is a `DataParallel` model. And your model also implements the " | |||||
f"`{fn}` method, which we can not call actually, we will" | |||||
" call `forward` function instead of `train_step` and you should note that.") | |||||
def test_step(self, batch) -> Dict: | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | |||||
elif fn not in {"train_step", "evaluate_step"}: | |||||
raise RuntimeError(f"There is no `{fn}` method in your model. And also notice that your model is a " | |||||
f"`DataParallel` model, which means that we will only call model.forward function " | |||||
f"when we are in forward propagation.") | |||||
return self.model, model.forward | |||||
else: | else: | ||||
return self._test_step(batch) | |||||
if hasattr(self.model, fn): | |||||
fn = getattr(self.model, fn) | |||||
if not callable(fn): | |||||
raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") | |||||
logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...') | |||||
return fn, None | |||||
elif fn in {"train_step", "evaluate_step"}: | |||||
logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...') | |||||
return self.model, self.model.forward | |||||
else: | |||||
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, | def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, | ||||
reproducible: bool = False): | reproducible: bool = False): | ||||
@@ -140,12 +101,18 @@ class TorchSingleDriver(TorchDriver): | |||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
if reproducible: | if reproducible: | ||||
batch_sampler = RandomBatchSampler( | |||||
batch_sampler=args.batch_sampler, | |||||
batch_size=args.batch_size, | |||||
drop_last=args.drop_last | |||||
) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
if isinstance(args.sampler, TorchRandomSampler): | |||||
# 如果本来就是随机的,直接替换掉吧。 | |||||
sampler = RandomSampler(args.sampler.data_source) | |||||
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") | |||||
return replace_sampler(dataloader, sampler) | |||||
else: | |||||
batch_sampler = RandomBatchSampler( | |||||
batch_sampler=args.batch_sampler, | |||||
batch_size=args.batch_size, | |||||
drop_last=args.drop_last | |||||
) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
else: | else: | ||||
return dataloader | return dataloader | ||||
@@ -81,7 +81,7 @@ class TorchDriver(Driver): | |||||
self.grad_scaler.update() | self.grad_scaler.update() | ||||
@staticmethod | @staticmethod | ||||
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||||
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||||
if is_train: | if is_train: | ||||
if not isinstance(dataloader, DataLoader): | if not isinstance(dataloader, DataLoader): | ||||
raise ValueError(f"Parameter `{dataloader_name}` should be 'DataLoader' type, not {type(dataloader)}.") | raise ValueError(f"Parameter `{dataloader_name}` should be 'DataLoader' type, not {type(dataloader)}.") | ||||
@@ -108,23 +108,6 @@ class TorchDriver(Driver): | |||||
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, " | raise ValueError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, " | ||||
f"not {type(each_optimizer)}.") | f"not {type(each_optimizer)}.") | ||||
def check_evaluator_mode(self, mode: str): | |||||
model = self.unwrap_model() | |||||
if mode == "validate": | |||||
if not hasattr(model, "validate_step"): | |||||
if hasattr(model, "test_step"): | |||||
logger.warning_once( | |||||
"Your model does not have 'validate_step' method but has 'test_step' method, but you" | |||||
"are using 'mode=validate', we are going to use 'test_step' to substitute for" | |||||
"'validate_step'.") | |||||
else: | |||||
if not hasattr(model, "test_step"): | |||||
if hasattr(model, "validate_step"): | |||||
logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you" | |||||
"are using 'mode=test', we are going to use 'validate_step' to substitute for" | |||||
"'test_step'.") | |||||
@staticmethod | @staticmethod | ||||
def tensor_to_numeric(tensor, reduce=None): | def tensor_to_numeric(tensor, reduce=None): | ||||
if tensor is None: | if tensor is None: | ||||
@@ -216,6 +199,7 @@ class TorchDriver(Driver): | |||||
num_consumed_batches = sampler_states['num_consumed_samples'] | num_consumed_batches = sampler_states['num_consumed_samples'] | ||||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | ||||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | ||||
states['sampler_states'] = sampler_states | |||||
else: | else: | ||||
raise RuntimeError( | raise RuntimeError( | ||||
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | ||||
@@ -90,14 +90,11 @@ class ForwardState(IntEnum): | |||||
PREDICT = 3 | PREDICT = 3 | ||||
_MODE_PARAMETER = "_forward_state" | |||||
class _DDPWrappingModel(Module): | class _DDPWrappingModel(Module): | ||||
""" | """ | ||||
该函数用于 DDP 训练时处理用户自己定制的 train_step 等函数; | 该函数用于 DDP 训练时处理用户自己定制的 train_step 等函数; | ||||
之所以要使用这一额外的包裹模型,是因为在使用 DDP 时,必须使用 DistributedDataParallel 的 forward 函数才能实现正常的运行; | 之所以要使用这一额外的包裹模型,是因为在使用 DDP 时,必须使用 DistributedDataParallel 的 forward 函数才能实现正常的运行; | ||||
另一方面,我们要求用户在使用我们的框架时,需要针对不用的模式实现不同的处理函数,例如 'train_step', 'validate_step' 等; | |||||
另一方面,我们要求用户在使用我们的框架时,需要针对不用的模式实现不同的处理函数,例如 'train_step', 'evaluate_step' 等; | |||||
然而,当使用 DistributedDataParallel 包裹 model 后,模型看不见其除了 forward 之外的方法;并且当我们尝试在训练过程中主动提取 | 然而,当使用 DistributedDataParallel 包裹 model 后,模型看不见其除了 forward 之外的方法;并且当我们尝试在训练过程中主动提取 | ||||
`model = model.module`,这同样会导致错误,会使得每一个gpu上的模型参数不同; | `model = model.module`,这同样会导致错误,会使得每一个gpu上的模型参数不同; | ||||
@@ -109,60 +106,18 @@ class _DDPWrappingModel(Module): | |||||
super(_DDPWrappingModel, self).__init__() | super(_DDPWrappingModel, self).__init__() | ||||
self.model = model | self.model = model | ||||
if hasattr(model, "train_step"): | |||||
self._train_step = model.train_step | |||||
self._train_signature_fn = None | |||||
else: | |||||
self._train_step = model | |||||
self._train_signature_fn = model.forward | |||||
if hasattr(model, "validate_step"): | |||||
self._validate_step = model.validate_step | |||||
self._validate_signature_fn = None | |||||
elif hasattr(model, "test_step"): | |||||
self._validate_step = model.test_step | |||||
self._validate_signature_fn = None | |||||
else: | |||||
self._validate_step = model | |||||
self._validate_signature_fn = model.forward | |||||
if hasattr(model, "test_step"): | |||||
self._test_step = model.test_step | |||||
self._test_signature_fn = None | |||||
elif hasattr(model, "validate_step"): | |||||
self._test_step = model.validate_step | |||||
self._test_signature_fn = None | |||||
else: | |||||
self._test_step = model | |||||
self._test_signature_fn = model.forward | |||||
def forward(self, batch, **kwargs) -> Dict: | def forward(self, batch, **kwargs) -> Dict: | ||||
""" | """ | ||||
pytorch lightning 实现了先 unwrapping_model 的操作,但是感觉对于我们来说没有什么必须要,先写个注释放这里,之后有需求了再看; | pytorch lightning 实现了先 unwrapping_model 的操作,但是感觉对于我们来说没有什么必须要,先写个注释放这里,之后有需求了再看; | ||||
""" | """ | ||||
forward_state = kwargs.pop(_MODE_PARAMETER) | |||||
fn = kwargs.pop("fastnlp_fn") | |||||
signature_fn = kwargs.pop("fastnlp_signature_fn") | |||||
wo_auto_param_call = kwargs.pop("wo_auto_param_call") | wo_auto_param_call = kwargs.pop("wo_auto_param_call") | ||||
if forward_state == ForwardState.TRAIN: | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | |||||
else: | |||||
return self._train_step(batch) | |||||
elif forward_state == ForwardState.VALIDATE: | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | |||||
else: | |||||
return self._validate_step(batch) | |||||
elif forward_state == ForwardState.TEST: | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | |||||
else: | |||||
return self._test_step(batch) | |||||
elif forward_state == ForwardState.PREDICT: | |||||
raise NotImplementedError("'PREDICT' mode has not been implemented.") | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(fn, batch, signature_fn=signature_fn) | |||||
else: | else: | ||||
raise NotImplementedError("You should direct a concrete mode.") | |||||
return fn(batch) | |||||
class DummyGradScaler: | class DummyGradScaler: | ||||
@@ -55,8 +55,8 @@ class TorchPaddleDriver(Driver): | |||||
self._train_step = self.model | self._train_step = self.model | ||||
self._train_signature_fn = self.model.forward | self._train_signature_fn = self.model.forward | ||||
if hasattr(self.model, "validate_step"): | |||||
self._validate_step = self.model.validate_step | |||||
if hasattr(self.model, "evaluate_step"): | |||||
self._validate_step = self.model.evaluate_step | |||||
self._validate_signature_fn = None | self._validate_signature_fn = None | ||||
elif hasattr(self.model, "test_step"): | elif hasattr(self.model, "test_step"): | ||||
self._validate_step = self.model.test_step | self._validate_step = self.model.test_step | ||||
@@ -68,8 +68,8 @@ class TorchPaddleDriver(Driver): | |||||
if hasattr(self.model, "test_step"): | if hasattr(self.model, "test_step"): | ||||
self._test_step = self.model.test_step | self._test_step = self.model.test_step | ||||
self._test_signature_fn = None | self._test_signature_fn = None | ||||
elif hasattr(self.model, "validate_step"): | |||||
self._test_step = self.model.validate_step | |||||
elif hasattr(self.model, "evaluate_step"): | |||||
self._test_step = self.model.evaluate_step | |||||
self._test_signature_fn = self.model.forward | self._test_signature_fn = self.model.forward | ||||
else: | else: | ||||
self._test_step = self.model | self._test_step = self.model | ||||
@@ -81,7 +81,7 @@ class TorchPaddleDriver(Driver): | |||||
self.model.to(self.model_device) | self.model.to(self.model_device) | ||||
@staticmethod | @staticmethod | ||||
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||||
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||||
if is_train: | if is_train: | ||||
if not isinstance(dataloader, (TorchDataLoader, PaddleDataLoader)): | if not isinstance(dataloader, (TorchDataLoader, PaddleDataLoader)): | ||||
raise ValueError(f"Parameter `{dataloader_name}` should be 'torch.util.data.DataLoader' or `paddle.io.dataloader` type, not {type(dataloader)}.") | raise ValueError(f"Parameter `{dataloader_name}` should be 'torch.util.data.DataLoader' or `paddle.io.dataloader` type, not {type(dataloader)}.") | ||||
@@ -211,9 +211,9 @@ def _add_file_handler(_logger: logging.Logger, path: Optional[Union[str, Path]] | |||||
raise TypeError("Parameter `remove_other_handlers` can only be `bool` type.") | raise TypeError("Parameter `remove_other_handlers` can only be `bool` type.") | ||||
if not isinstance(mode, str): | if not isinstance(mode, str): | ||||
raise TypeError("Parameter 'mode' can only be `str` type.") | |||||
raise TypeError("Parameter 'evaluate_fn' can only be `str` type.") | |||||
if mode not in {"w", "a"}: | if mode not in {"w", "a"}: | ||||
raise ValueError("Parameter `mode` can only be one of these values: ('w', 'a').") | |||||
raise ValueError("Parameter `evaluate_fn` can only be one of these values: ('w', 'a').") | |||||
for h in _logger.handlers: | for h in _logger.handlers: | ||||
if isinstance(h, logging.FileHandler): | if isinstance(h, logging.FileHandler): | ||||
@@ -230,7 +230,7 @@ def _add_file_handler(_logger: logging.Logger, path: Optional[Union[str, Path]] | |||||
dirname = os.path.abspath(os.path.dirname(path)) | dirname = os.path.abspath(os.path.dirname(path)) | ||||
os.makedirs(dirname, exist_ok=True) | os.makedirs(dirname, exist_ok=True) | ||||
# 这里只要检测到是分布式训练,我们就将 mode 改为 "a";这样会导致的一个问题在于,如果第二次训练也是分布式训练,logger记录的log不会重新 | |||||
# 这里只要检测到是分布式训练,我们就将 evaluate_fn 改为 "a";这样会导致的一个问题在于,如果第二次训练也是分布式训练,logger记录的log不会重新 | |||||
# 覆盖掉原文件,而是会接着上一次的 log 继续添加; | # 覆盖掉原文件,而是会接着上一次的 log 继续添加; | ||||
# 这样做主要是为了解决这样的情形所导致的问题:在分布式训练中,进程 1 比 进程 0 先运行到这里,然后使得进程 0 将进程 1 的 log 覆盖掉; | # 这样做主要是为了解决这样的情形所导致的问题:在分布式训练中,进程 1 比 进程 0 先运行到这里,然后使得进程 0 将进程 1 的 log 覆盖掉; | ||||
if is_cur_env_distributed():# and int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) != 0: | if is_cur_env_distributed():# and int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) != 0: | ||||
@@ -164,7 +164,7 @@ def _get_keys(args:List[Dict]) -> List[List[str]]: | |||||
return _provided_keys | return _provided_keys | ||||
def _get_fun_msg(fn)->str: | |||||
def _get_fun_msg(fn, with_fp=True)->str: | |||||
""" | """ | ||||
获取函数的基本信息,帮助报错。 | 获取函数的基本信息,帮助报错。 | ||||
ex: | ex: | ||||
@@ -172,6 +172,7 @@ def _get_fun_msg(fn)->str: | |||||
# `_get_fun_msg(fn) -> str`(In file:/Users/hnyan/Desktop/projects/fastNLP/fastNLP/fastNLP/core/utils/utils.py) | # `_get_fun_msg(fn) -> str`(In file:/Users/hnyan/Desktop/projects/fastNLP/fastNLP/fastNLP/core/utils/utils.py) | ||||
:param callable fn: | :param callable fn: | ||||
:param with_fp: 是否包含函数所在的文件信息。 | |||||
:return: | :return: | ||||
""" | """ | ||||
if isinstance(fn, functools.partial): | if isinstance(fn, functools.partial): | ||||
@@ -180,9 +181,12 @@ def _get_fun_msg(fn)->str: | |||||
fn_name = fn.__qualname__ + str(inspect.signature(fn)) | fn_name = fn.__qualname__ + str(inspect.signature(fn)) | ||||
except: | except: | ||||
fn_name = str(fn) | fn_name = str(fn) | ||||
try: | |||||
fp = '(In file:' + os.path.abspath(inspect.getfile(fn)) + ')' | |||||
except: | |||||
if with_fp: | |||||
try: | |||||
fp = '(In file:' + os.path.abspath(inspect.getfile(fn)) + ')' | |||||
except: | |||||
fp = '' | |||||
else: | |||||
fp = '' | fp = '' | ||||
msg = f'`{fn_name}`' + fp | msg = f'`{fn_name}`' + fp | ||||
return msg | return msg | ||||
@@ -37,7 +37,7 @@ class TrainerParameters: | |||||
model: Any = None | model: Any = None | ||||
optimizers: Any = None | optimizers: Any = None | ||||
train_dataloader: Any = None | train_dataloader: Any = None | ||||
validate_dataloaders: Any = None | |||||
evaluate_dataloaders: Any = None | |||||
input_mapping: Any = None | input_mapping: Any = None | ||||
output_mapping: Any = None | output_mapping: Any = None | ||||
metrics: Any = None | metrics: Any = None | ||||
@@ -63,7 +63,7 @@ def model_and_optimizers(request): | |||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
trainer_params.train_dataloader = _dataloader | trainer_params.train_dataloader = _dataloader | ||||
trainer_params.validate_dataloaders = _dataloader | |||||
trainer_params.evaluate_dataloaders = _dataloader | |||||
trainer_params.metrics = {"acc": Accuracy()} | trainer_params.metrics = {"acc": Accuracy()} | ||||
return trainer_params | return trainer_params | ||||
@@ -124,7 +124,7 @@ def test_model_checkpoint_callback_1( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -204,7 +204,7 @@ def test_model_checkpoint_callback_1( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -264,7 +264,7 @@ def test_model_checkpoint_callback_2( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -302,7 +302,7 @@ def test_model_checkpoint_callback_2( | |||||
device=4, | device=4, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -370,7 +370,7 @@ def test_trainer_checkpoint_callback_1( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -448,7 +448,7 @@ def test_trainer_checkpoint_callback_1( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -473,12 +473,12 @@ def test_trainer_checkpoint_callback_1( | |||||
@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | @pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | ||||
@pytest.mark.parametrize("version", [0, 1]) | @pytest.mark.parametrize("version", [0, 1]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.skip("Skip transformers test for now.") | |||||
def test_trainer_checkpoint_callback_2( | def test_trainer_checkpoint_callback_2( | ||||
driver, | driver, | ||||
device, | device, | ||||
version | version | ||||
): | ): | ||||
pytest.skip("Skip transformers test for now.") | |||||
path = Path.cwd().joinpath(f"test_model_checkpoint") | path = Path.cwd().joinpath(f"test_model_checkpoint") | ||||
path.mkdir(exist_ok=True, parents=True) | path.mkdir(exist_ok=True, parents=True) | ||||
@@ -626,7 +626,7 @@ def test_trainer_checkpoint_callback_2( | |||||
train_dataloader=test_bert_dataloader_train, | train_dataloader=test_bert_dataloader_train, | ||||
optimizers=test_bert_optimizers, | optimizers=test_bert_optimizers, | ||||
validate_dataloaders=test_bert_dataloader_validate, | |||||
evaluate_dataloaders=test_bert_dataloader_validate, | |||||
input_mapping=bert_input_mapping, | input_mapping=bert_input_mapping, | ||||
output_mapping=bert_output_mapping, | output_mapping=bert_output_mapping, | ||||
metrics={"acc": acc}, | metrics={"acc": acc}, | ||||
@@ -700,7 +700,7 @@ def test_trainer_checkpoint_callback_2( | |||||
train_dataloader=test_bert_dataloader_train, | train_dataloader=test_bert_dataloader_train, | ||||
optimizers=test_bert_optimizers, | optimizers=test_bert_optimizers, | ||||
validate_dataloaders=test_bert_dataloader_validate, | |||||
evaluate_dataloaders=test_bert_dataloader_validate, | |||||
input_mapping=bert_input_mapping, | input_mapping=bert_input_mapping, | ||||
output_mapping=bert_output_mapping, | output_mapping=bert_output_mapping, | ||||
metrics={"acc": acc}, | metrics={"acc": acc}, | ||||
@@ -40,7 +40,7 @@ class TrainerParameters: | |||||
model: Any = None | model: Any = None | ||||
optimizers: Any = None | optimizers: Any = None | ||||
train_dataloader: Any = None | train_dataloader: Any = None | ||||
validate_dataloaders: Any = None | |||||
evaluate_dataloaders: Any = None | |||||
input_mapping: Any = None | input_mapping: Any = None | ||||
output_mapping: Any = None | output_mapping: Any = None | ||||
metrics: Any = None | metrics: Any = None | ||||
@@ -66,7 +66,7 @@ def model_and_optimizers(request): | |||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
trainer_params.train_dataloader = _dataloader | trainer_params.train_dataloader = _dataloader | ||||
trainer_params.validate_dataloaders = _dataloader | |||||
trainer_params.evaluate_dataloaders = _dataloader | |||||
trainer_params.metrics = {"acc": Accuracy()} | trainer_params.metrics = {"acc": Accuracy()} | ||||
return trainer_params | return trainer_params | ||||
@@ -92,7 +92,7 @@ def test_load_best_model_callback( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=lambda output: output if ('loss' in output) else {'pred':output['preds'], 'target': output['target']}, | output_mapping=lambda output: output if ('loss' in output) else {'pred':output['preds'], 'target': output['target']}, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -105,7 +105,7 @@ def test_load_best_model_callback( | |||||
driver = TorchSingleDriver(model_and_optimizers.model, device=torch.device('cuda')) | driver = TorchSingleDriver(model_and_optimizers.model, device=torch.device('cuda')) | ||||
evaluator = Evaluator(model_and_optimizers.model, driver=driver, device=device, | evaluator = Evaluator(model_and_optimizers.model, driver=driver, device=device, | ||||
dataloaders={'dl1': model_and_optimizers.validate_dataloaders}, | |||||
dataloaders={'dl1': model_and_optimizers.evaluate_dataloaders}, | |||||
metrics={'acc': Accuracy(aggregate_when_get_metric=False)}, | metrics={'acc': Accuracy(aggregate_when_get_metric=False)}, | ||||
output_mapping=lambda output: output if ('loss' in output) else {'pred':output['preds'], 'target': output['target']}, | output_mapping=lambda output: output if ('loss' in output) else {'pred':output['preds'], 'target': output['target']}, | ||||
progress_bar='rich', use_dist_sampler=False) | progress_bar='rich', use_dist_sampler=False) | ||||
@@ -75,7 +75,7 @@ _dataloader = DataLoader( | |||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
train_dataloader = _dataloader | train_dataloader = _dataloader | ||||
validate_dataloaders = _dataloader | |||||
evaluate_dataloaders = _dataloader | |||||
metrics = {"acc": Accuracy()} | metrics = {"acc": Accuracy()} | ||||
@@ -89,7 +89,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
device=None, | device=None, | ||||
optimizers=optimizers, | optimizers=optimizers, | ||||
train_dataloader=train_dataloader, | train_dataloader=train_dataloader, | ||||
validate_dataloaders=validate_dataloaders, | |||||
evaluate_dataloaders=evaluate_dataloaders, | |||||
metrics=metrics, | metrics=metrics, | ||||
n_epochs=2, | n_epochs=2, | ||||
@@ -6,7 +6,7 @@ python -m torch.distributed.launch --nproc_per_node 2 tests/core/controllers/_te | |||||
import argparse | import argparse | ||||
import os | import os | ||||
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2" | |||||
import sys | import sys | ||||
path = os.path.abspath(__file__) | path = os.path.abspath(__file__) | ||||
@@ -63,7 +63,7 @@ _dataloader = DataLoader( | |||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
train_dataloader = _dataloader | train_dataloader = _dataloader | ||||
validate_dataloaders = _dataloader | |||||
evaluate_dataloaders = _dataloader | |||||
metrics = {"acc": Accuracy()} | metrics = {"acc": Accuracy()} | ||||
@@ -77,7 +77,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
device=None, | device=None, | ||||
optimizers=optimizers, | optimizers=optimizers, | ||||
train_dataloader=train_dataloader, | train_dataloader=train_dataloader, | ||||
validate_dataloaders=validate_dataloaders, | |||||
evaluate_dataloaders=evaluate_dataloaders, | |||||
metrics=metrics, | metrics=metrics, | ||||
n_epochs=2, | n_epochs=2, | ||||
@@ -30,7 +30,7 @@ class TrainerParameters: | |||||
model: Any = None | model: Any = None | ||||
optimizers: Any = None | optimizers: Any = None | ||||
train_dataloader: Any = None | train_dataloader: Any = None | ||||
validate_dataloaders: Any = None | |||||
evaluate_dataloaders: Any = None | |||||
input_mapping: Any = None | input_mapping: Any = None | ||||
output_mapping: Any = None | output_mapping: Any = None | ||||
metrics: Any = None | metrics: Any = None | ||||
@@ -57,7 +57,7 @@ def model_and_optimizers(): | |||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
trainer_params.train_dataloader = _dataloader | trainer_params.train_dataloader = _dataloader | ||||
trainer_params.validate_dataloaders = _dataloader | |||||
trainer_params.evaluate_dataloaders = _dataloader | |||||
trainer_params.metrics = {"acc": Accuracy()} | trainer_params.metrics = {"acc": Accuracy()} | ||||
return trainer_params | return trainer_params | ||||
@@ -82,7 +82,7 @@ def test_trainer_event_trigger( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -64,8 +64,8 @@ def test_trainer_fleet( | |||||
device=device, | device=device, | ||||
optimizers=optimizers, | optimizers=optimizers, | ||||
train_dataloader=train_dataloader, | train_dataloader=train_dataloader, | ||||
validate_dataloaders=validate_dataloaders, | |||||
validate_every=validate_every, | |||||
evaluate_dataloaders=validate_dataloaders, | |||||
evaluate_every=validate_every, | |||||
input_mapping=None, | input_mapping=None, | ||||
output_mapping=None, | output_mapping=None, | ||||
metrics=metrics, | metrics=metrics, | ||||
@@ -70,8 +70,8 @@ def test_trainer_fleet( | |||||
device=device, | device=device, | ||||
optimizers=optimizers, | optimizers=optimizers, | ||||
train_dataloader=train_dataloader, | train_dataloader=train_dataloader, | ||||
validate_dataloaders=validate_dataloaders, | |||||
validate_every=validate_every, | |||||
evaluate_dataloaders=validate_dataloaders, | |||||
evaluate_every=validate_every, | |||||
input_mapping=None, | input_mapping=None, | ||||
output_mapping=None, | output_mapping=None, | ||||
metrics=metrics, | metrics=metrics, | ||||
@@ -1,19 +1,20 @@ | |||||
import pytest | import pytest | ||||
import os | import os | ||||
os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
from typing import Any | from typing import Any | ||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
from paddle.optimizer import Adam | |||||
from paddle.io import DataLoader | |||||
from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
from fastNLP.core.metrics.accuracy import Accuracy | from fastNLP.core.metrics.accuracy import Accuracy | ||||
from fastNLP.core.callbacks.progress_callback import RichCallback | from fastNLP.core.callbacks.progress_callback import RichCallback | ||||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK | from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK | ||||
from paddle.optimizer import Adam | |||||
from paddle.io import DataLoader | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | |||||
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||||
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | |||||
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback | from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback | ||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
@@ -48,64 +49,31 @@ class TrainerParameters: | |||||
output_mapping: Any = None | output_mapping: Any = None | ||||
metrics: Any = None | metrics: Any = None | ||||
# @pytest.fixture(params=[0], autouse=True) | |||||
# def model_and_optimizers(request): | |||||
# """ | |||||
# 初始化单卡模式的模型和优化器 | |||||
# """ | |||||
# trainer_params = TrainerParameters() | |||||
# print(paddle.device.get_device()) | |||||
# if request.param == 0: | |||||
# trainer_params.model = PaddleNormalModel_Classification( | |||||
# num_labels=MNISTTrainPaddleConfig.num_labels, | |||||
# feature_dimension=MNISTTrainPaddleConfig.feature_dimension | |||||
# ) | |||||
# trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001) | |||||
# train_dataloader = DataLoader( | |||||
# dataset=PaddleDataset_MNIST("train"), | |||||
# batch_size=MNISTTrainPaddleConfig.batch_size, | |||||
# shuffle=True | |||||
# ) | |||||
# val_dataloader = DataLoader( | |||||
# dataset=PaddleDataset_MNIST(mode="test"), | |||||
# batch_size=MNISTTrainPaddleConfig.batch_size, | |||||
# shuffle=True | |||||
# ) | |||||
# trainer_params.train_dataloader = train_dataloader | |||||
# trainer_params.validate_dataloaders = val_dataloader | |||||
# trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every | |||||
# trainer_params.metrics = {"acc": Accuracy()} | |||||
# return trainer_params | |||||
@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)]) | |||||
@pytest.mark.parametrize("driver,device", [("paddle", "cpu")("paddle", 1)]) | |||||
# @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])]) | # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])]) | ||||
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True), | @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True), | ||||
RichCallback(5), RecordLossCallback(loss_threshold=0.3)]]) | RichCallback(5), RecordLossCallback(loss_threshold=0.3)]]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_paddle( | def test_trainer_paddle( | ||||
# model_and_optimizers: TrainerParameters, | |||||
driver, | driver, | ||||
device, | device, | ||||
callbacks, | callbacks, | ||||
n_epochs=15, | |||||
n_epochs=2, | |||||
): | ): | ||||
trainer_params = TrainerParameters() | trainer_params = TrainerParameters() | ||||
trainer_params.model = PaddleNormalModel_Classification( | |||||
trainer_params.model = PaddleNormalModel_Classification_1( | |||||
num_labels=MNISTTrainPaddleConfig.num_labels, | num_labels=MNISTTrainPaddleConfig.num_labels, | ||||
feature_dimension=MNISTTrainPaddleConfig.feature_dimension | feature_dimension=MNISTTrainPaddleConfig.feature_dimension | ||||
) | ) | ||||
trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001) | trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001) | ||||
train_dataloader = DataLoader( | train_dataloader = DataLoader( | ||||
dataset=PaddleDataset_MNIST("train"), | |||||
dataset=PaddleRandomMaxDataset(6400, 10), | |||||
batch_size=MNISTTrainPaddleConfig.batch_size, | batch_size=MNISTTrainPaddleConfig.batch_size, | ||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
val_dataloader = DataLoader( | val_dataloader = DataLoader( | ||||
dataset=PaddleDataset_MNIST(mode="test"), | |||||
dataset=PaddleRandomMaxDataset(1000, 10), | |||||
batch_size=MNISTTrainPaddleConfig.batch_size, | batch_size=MNISTTrainPaddleConfig.batch_size, | ||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
@@ -113,39 +81,19 @@ def test_trainer_paddle( | |||||
trainer_params.validate_dataloaders = val_dataloader | trainer_params.validate_dataloaders = val_dataloader | ||||
trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every | trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every | ||||
trainer_params.metrics = {"acc": Accuracy(backend="paddle")} | trainer_params.metrics = {"acc": Accuracy(backend="paddle")} | ||||
if not isinstance(device, (int, str)) and len(device) > 1 and FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||||
with pytest.raises(SystemExit) as exc: | |||||
trainer = Trainer( | |||||
model=trainer_params.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=trainer_params.optimizers, | |||||
train_dataloader=trainer_params.train_dataloader, | |||||
validate_dataloaders=trainer_params.validate_dataloaders, | |||||
validate_every=trainer_params.validate_every, | |||||
input_mapping=trainer_params.input_mapping, | |||||
output_mapping=trainer_params.output_mapping, | |||||
metrics=trainer_params.metrics, | |||||
n_epochs=n_epochs, | |||||
callbacks=callbacks, | |||||
) | |||||
assert exc.value.code == 0 | |||||
return | |||||
else: | |||||
trainer = Trainer( | |||||
model=trainer_params.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=trainer_params.optimizers, | |||||
train_dataloader=trainer_params.train_dataloader, | |||||
validate_dataloaders=trainer_params.validate_dataloaders, | |||||
validate_every=trainer_params.validate_every, | |||||
input_mapping=trainer_params.input_mapping, | |||||
output_mapping=trainer_params.output_mapping, | |||||
metrics=trainer_params.metrics, | |||||
n_epochs=n_epochs, | |||||
callbacks=callbacks, | |||||
) | |||||
trainer.run() | |||||
trainer = Trainer( | |||||
model=trainer_params.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=trainer_params.optimizers, | |||||
train_dataloader=trainer_params.train_dataloader, | |||||
validate_dataloaders=trainer_params.validate_dataloaders, | |||||
validate_every=trainer_params.validate_every, | |||||
input_mapping=trainer_params.input_mapping, | |||||
output_mapping=trainer_params.output_mapping, | |||||
metrics=trainer_params.metrics, | |||||
n_epochs=n_epochs, | |||||
callbacks=callbacks, | |||||
) | |||||
trainer.run() |
@@ -43,7 +43,7 @@ class TrainerParameters: | |||||
model: Any = None | model: Any = None | ||||
optimizers: Any = None | optimizers: Any = None | ||||
train_dataloader: Any = None | train_dataloader: Any = None | ||||
validate_dataloaders: Any = None | |||||
evaluate_dataloaders: Any = None | |||||
input_mapping: Any = None | input_mapping: Any = None | ||||
output_mapping: Any = None | output_mapping: Any = None | ||||
metrics: Any = None | metrics: Any = None | ||||
@@ -71,7 +71,7 @@ def model_and_optimizers(request): | |||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
trainer_params.train_dataloader = _dataloader | trainer_params.train_dataloader = _dataloader | ||||
trainer_params.validate_dataloaders = _dataloader | |||||
trainer_params.evaluate_dataloaders = _dataloader | |||||
trainer_params.metrics = {"acc": Accuracy()} | trainer_params.metrics = {"acc": Accuracy()} | ||||
elif request.param == 1: | elif request.param == 1: | ||||
@@ -91,23 +91,23 @@ def model_and_optimizers(request): | |||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
trainer_params.train_dataloader = _dataloader | trainer_params.train_dataloader = _dataloader | ||||
trainer_params.validate_dataloaders = _dataloader | |||||
trainer_params.evaluate_dataloaders = _dataloader | |||||
trainer_params.metrics = {"acc": Accuracy()} | trainer_params.metrics = {"acc": Accuracy()} | ||||
return trainer_params | return trainer_params | ||||
# 测试一下普通的情况; | # 测试一下普通的情况; | ||||
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) | |||||
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]]) | @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]]) | ||||
@pytest.mark.parametrize("validate_every", [-3]) | |||||
@pytest.mark.parametrize("evaluate_every", [-3, -1, 100]) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_torch_with_evaluator( | def test_trainer_torch_with_evaluator( | ||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
driver, | driver, | ||||
device, | device, | ||||
callbacks, | callbacks, | ||||
validate_every, | |||||
evaluate_every, | |||||
n_epochs=10, | n_epochs=10, | ||||
): | ): | ||||
trainer = Trainer( | trainer = Trainer( | ||||
@@ -116,11 +116,11 @@ def test_trainer_torch_with_evaluator( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
validate_every=validate_every, | |||||
evaluate_every=evaluate_every, | |||||
n_epochs=n_epochs, | n_epochs=n_epochs, | ||||
callbacks=callbacks, | callbacks=callbacks, | ||||
@@ -152,7 +152,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -193,14 +193,14 @@ def test_trainer_validate_every( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
n_epochs=n_epochs, | n_epochs=n_epochs, | ||||
output_from_new_proc="all", | output_from_new_proc="all", | ||||
validate_every=validate_every | |||||
evaluate_every=validate_every | |||||
) | ) | ||||
trainer.run() | trainer.run() | ||||
@@ -38,7 +38,7 @@ class TrainerParameters: | |||||
model: Any = None | model: Any = None | ||||
optimizers: Any = None | optimizers: Any = None | ||||
train_dataloader: Any = None | train_dataloader: Any = None | ||||
validate_dataloaders: Any = None | |||||
evaluate_dataloaders: Any = None | |||||
input_mapping: Any = None | input_mapping: Any = None | ||||
output_mapping: Any = None | output_mapping: Any = None | ||||
metrics: Any = None | metrics: Any = None | ||||
@@ -65,7 +65,7 @@ def model_and_optimizers(request): | |||||
batch_size=NormalClassificationTrainTorchConfig.batch_size, | batch_size=NormalClassificationTrainTorchConfig.batch_size, | ||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
trainer_params.validate_dataloaders = None | |||||
trainer_params.evaluate_dataloaders = None | |||||
trainer_params.input_mapping = None | trainer_params.input_mapping = None | ||||
trainer_params.output_mapping = None | trainer_params.output_mapping = None | ||||
@@ -91,7 +91,7 @@ def test_trainer_torch_without_evaluator( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -126,7 +126,7 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -163,7 +163,7 @@ def test_trainer_torch_without_evaluator_accumulation_steps( | |||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -202,7 +202,7 @@ def test_trainer_output_from_new_proc( | |||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -267,7 +267,7 @@ def test_trainer_on_exception( | |||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -1,21 +1,35 @@ | |||||
from dataclasses import replace | |||||
import pytest | import pytest | ||||
import os | import os | ||||
import numpy as np | |||||
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle | |||||
set_env_on_import_paddle() | |||||
os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver | |||||
from fastNLP.core.samplers import ( | |||||
RandomSampler, | |||||
UnrepeatedSampler, | |||||
BucketedBatchSampler, | |||||
UnrepeatedRandomSampler, | |||||
UnrepeatedSequentialSampler, | |||||
) | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
import paddle | import paddle | ||||
import paddle.distributed as dist | import paddle.distributed as dist | ||||
from paddle.io import DataLoader | |||||
from paddle.io import DataLoader, BatchSampler | |||||
from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver | |||||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler | |||||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | |||||
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
from fastNLP.core import synchronize_safe_rm | |||||
def generate_driver(num_labels, feature_dimension): | |||||
paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension) | |||||
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | |||||
driver = PaddleFleetDriver( | |||||
model=paddle_model, | |||||
parallel_device=[0,1], | |||||
) | |||||
driver.set_optimizers(paddle_opt) | |||||
driver.setup() | |||||
return driver | |||||
############################################################################ | ############################################################################ | ||||
# | # | ||||
@@ -23,269 +37,340 @@ from fastNLP.core import synchronize_safe_rm | |||||
# | # | ||||
############################################################################ | ############################################################################ | ||||
@magic_argv_env_context | |||||
def test_move_data_to_device(): | |||||
""" | |||||
这个函数仅调用了paddle_move_data_to_device,测试例在tests/core/utils/test_paddle_utils.py中 | |||||
就不重复测试了 | |||||
""" | |||||
try: | |||||
paddle_model = PaddleNormalModel_Classification(10, 784) | |||||
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | |||||
driver = PaddleFleetDriver( | |||||
model=paddle_model, | |||||
parallel_device=[0,1], | |||||
) | |||||
driver.set_optimizers(paddle_opt) | |||||
# 区分launch和子进程setup的时候 | |||||
if FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||||
with pytest.raises(SystemExit) as e: | |||||
driver.setup() | |||||
assert e.value.code == 0 | |||||
return | |||||
else: | |||||
driver.setup() | |||||
driver.move_data_to_device(paddle.rand((32, 64))) | |||||
finally: | |||||
synchronize_safe_rm("log") | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_is_distributed(): | |||||
print(os.getenv("CUDA_VISIBLE_DEVICES")) | |||||
print(paddle.device.get_device()) | |||||
try: | |||||
paddle_model = PaddleNormalModel_Classification(10, 784) | |||||
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | |||||
driver = PaddleFleetDriver( | |||||
model=paddle_model, | |||||
parallel_device=[0,1], | |||||
output_from_new_proc='all' | |||||
) | |||||
driver.set_optimizers(paddle_opt) | |||||
# 区分launch和子进程setup的时候 | |||||
if FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||||
with pytest.raises(SystemExit) as e: | |||||
driver.setup() | |||||
assert e.value.code == 0 | |||||
return | |||||
else: | |||||
driver.setup() | |||||
assert driver.is_distributed() == True | |||||
finally: | |||||
synchronize_safe_rm("log") | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_get_no_sync_context(): | |||||
class TestFleetDriverFunction: | |||||
""" | """ | ||||
测试能否运行 | |||||
测试 PaddleFleetDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 | |||||
""" | """ | ||||
try: | |||||
paddle_model = PaddleNormalModel_Classification(10, 784) | |||||
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | |||||
driver = PaddleFleetDriver( | |||||
model=paddle_model, | |||||
parallel_device=[0,1], | |||||
) | |||||
driver.set_optimizers(paddle_opt) | |||||
# 区分launch和子进程setup的时候 | |||||
if FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||||
with pytest.raises(SystemExit) as e: | |||||
driver.setup() | |||||
assert e.value.code == 0 | |||||
return | |||||
else: | |||||
driver.setup() | |||||
res = driver.get_no_sync_context() | |||||
finally: | |||||
synchronize_safe_rm("log") | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_is_global_zero(): | |||||
try: | |||||
paddle_model = PaddleNormalModel_Classification(10, 784) | |||||
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | |||||
driver = PaddleFleetDriver( | |||||
model=paddle_model, | |||||
parallel_device=[0,1], | |||||
) | |||||
driver.set_optimizers(paddle_opt) | |||||
# 区分launch和子进程setup的时候 | |||||
if FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||||
with pytest.raises(SystemExit) as e: | |||||
driver.setup() | |||||
assert e.value.code == 0 | |||||
return | |||||
else: | |||||
driver.setup() | |||||
driver.is_global_zero() | |||||
finally: | |||||
synchronize_safe_rm("log") | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_unwrap_model(): | |||||
try: | |||||
paddle_model = PaddleNormalModel_Classification(10, 784) | |||||
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | |||||
driver = PaddleFleetDriver( | |||||
model=paddle_model, | |||||
parallel_device=[0,1], | |||||
) | |||||
driver.set_optimizers(paddle_opt) | |||||
# 区分launch和子进程setup的时候 | |||||
if FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||||
with pytest.raises(SystemExit) as e: | |||||
driver.setup() | |||||
assert e.value.code == 0 | |||||
return | |||||
else: | |||||
driver.setup() | |||||
driver.unwrap_model() | |||||
finally: | |||||
synchronize_safe_rm("log") | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_get_local_rank(): | |||||
try: | |||||
paddle_model = PaddleNormalModel_Classification(10, 784) | |||||
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | |||||
driver = PaddleFleetDriver( | |||||
model=paddle_model, | |||||
parallel_device=[0,1], | |||||
) | |||||
driver.set_optimizers(paddle_opt) | |||||
# 区分launch和子进程setup的时候 | |||||
if FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||||
with pytest.raises(SystemExit) as e: | |||||
driver.setup() | |||||
assert e.value.code == 0 | |||||
return | |||||
else: | |||||
driver.setup() | |||||
driver.get_local_rank() | |||||
finally: | |||||
synchronize_safe_rm("log") | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize( | |||||
"dist_sampler", | |||||
["dist", "unrepeatdist", RandomSampler(PaddleDataset_MNIST("train"))] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"reproducible", | |||||
[True, False] | |||||
) | |||||
def test_replace_sampler(dist_sampler, reproducible): | |||||
""" | |||||
测试replace_sampler | |||||
""" | |||||
try: | |||||
paddle_model = PaddleNormalModel_Classification(10, 784) | |||||
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | |||||
driver = PaddleFleetDriver( | |||||
model=paddle_model, | |||||
parallel_device=[0,1], | |||||
) | |||||
driver.set_optimizers(paddle_opt) | |||||
# 区分launch和子进程setup的时候 | |||||
if FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||||
with pytest.raises(SystemExit) as e: | |||||
driver.setup() | |||||
assert e.value.code == 0 | |||||
return | |||||
else: | |||||
driver.setup() | |||||
dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True) | |||||
driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible) | |||||
finally: | |||||
synchronize_safe_rm("log") | |||||
dist.barrier() | |||||
@classmethod | |||||
def setup_class(cls): | |||||
cls.driver = generate_driver(10, 10) | |||||
@magic_argv_env_context | |||||
def test_move_data_to_device(self): | |||||
""" | |||||
这个函数仅调用了paddle_move_data_to_device,测试例在tests/core/utils/test_paddle_utils.py中 | |||||
就不重复测试了 | |||||
""" | |||||
self.driver.move_data_to_device(paddle.rand((32, 64))) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_is_distributed(self): | |||||
""" | |||||
测试 is_distributed 函数 | |||||
""" | |||||
assert self.driver.is_distributed() == True | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_get_no_sync_context(self): | |||||
""" | |||||
测试 get_no_sync_context 函数 | |||||
""" | |||||
res = self.driver.get_no_sync_context() | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_is_global_zero(self): | |||||
""" | |||||
测试 is_global_zero 函数 | |||||
""" | |||||
self.driver.is_global_zero() | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_unwrap_model(self): | |||||
""" | |||||
测试 unwrap_model 函数 | |||||
""" | |||||
self.driver.unwrap_model() | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_get_local_rank(self): | |||||
""" | |||||
测试 get_local_rank 函数 | |||||
""" | |||||
self.driver.get_local_rank() | |||||
dist.barrier() | |||||
############################################################################ | ############################################################################ | ||||
# | # | ||||
# 测试单机多卡的训练情况 | |||||
# 测试 set_dist_repro_dataloader 函数 | |||||
# | # | ||||
############################################################################ | ############################################################################ | ||||
@magic_argv_env_context | |||||
class SingleMachineMultiGPUTrainingTestCase: | |||||
class TestSetDistReproDataloader: | |||||
@classmethod | |||||
def setup_class(cls): | |||||
cls.driver = generate_driver(10, 10) | |||||
def setup_method(self): | |||||
self.dataset = PaddleNormalDataset(20) | |||||
""" | """ | ||||
测试在单机多卡上使用PaddleFleetDriver进行训练。 | |||||
分布式训练用pytest会有些混乱 | |||||
传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 | |||||
此时对应 driver.load 中的情况 | |||||
""" | """ | ||||
def test_case1(self): | |||||
gpus = [0, 1] | |||||
lr = 0.0003 | |||||
epochs = 20 | |||||
@magic_argv_env_context | |||||
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) | |||||
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | |||||
assert replaced_loader.batch_sampler is batch_sampler | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_set_dist_repro_dataloader_with_dist_sampler(self): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) | |||||
sampler = RandomSampler(self.dataset, shuffle=True) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert replaced_loader.batch_sampler.sampler is sampler | |||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
dist.barrier() | |||||
""" | |||||
传入的参数 `dist` 为 None 的情况,这种情况出现在 trainer 和 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` | |||||
参数为 False。此时函数会根据 `reproducible` 的设置进行不同的处理。 | |||||
当 `reproducible` 为 False 时,需要根据 dataloader 的 batch_sampler 或 sampler 是否为 Reproducible 来决定 | |||||
是否重新实例化 dataloader | |||||
""" | |||||
paddle_model = PaddleNormalModel_Classification() | |||||
@magic_argv_env_context | |||||
def test_set_dist_repro_dataloader_with_dist_none_reproducible_true(self): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现 | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) | |||||
with pytest.raises(RuntimeError): | |||||
# 应当抛出 RuntimeError | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler | |||||
时的表现 | |||||
""" | |||||
dataloader = DataLoader( | |||||
self.dataset, | |||||
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4), | |||||
) | |||||
dataloader.batch_sampler.set_distributed( | |||||
num_replicas=self.driver.world_size, | |||||
rank=self.driver.global_rank, | |||||
pad=True | |||||
) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
self.check_distributed_sampler(dataloader.batch_sampler) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现 | |||||
""" | |||||
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) | |||||
batch_sampler.sampler = RandomSampler(self.dataset, True) | |||||
batch_sampler.sampler.set_distributed( | |||||
num_replicas=self.driver.world_size, | |||||
rank=self.driver.global_rank | |||||
) | |||||
dataloader = DataLoader( | |||||
self.dataset, | |||||
batch_sampler=batch_sampler | |||||
) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 2 | |||||
assert replaced_loader.batch_sampler.drop_last == False | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现 | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
assert replaced_loader is dataloader | |||||
dist.barrier() | |||||
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=lr) | |||||
""" | |||||
传入的参数 `dist` 为 'dist' 的情况,这种情况出现在 trainer 的初始化过程中,用户指定了 `use_dist_sampler` 参数 | |||||
为 True。此时函数会根据 dataloader 的 batch_sampler 或 sampler 是否为 Reproducible 来决定如何重新实例化 dataloader | |||||
""" | |||||
train_dataset = PaddleDataset_MNIST("train") | |||||
test_dataset = PaddleDataset_MNIST("test") | |||||
loss_func = paddle.nn.CrossEntropyLoss() | |||||
@magic_argv_env_context | |||||
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler | |||||
的表现 | |||||
""" | |||||
dataloader = DataLoader( | |||||
dataset=self.dataset, | |||||
batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4) | |||||
) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler | |||||
的表现 | |||||
""" | |||||
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) | |||||
batch_sampler.sampler = RandomSampler(self.dataset, True) | |||||
dataloader = DataLoader( | |||||
self.dataset, | |||||
batch_sampler=batch_sampler | |||||
) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 2 | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == True | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现 | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == True | |||||
dist.barrier() | |||||
dataloader = DataLoader(train_dataset, batch_size=100, shuffle=True) | |||||
""" | |||||
传入的参数 `dist` 为 'unrepeatdist' 的情况,这种情况出现在 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` 参数 | |||||
为 True。此时函数会根据 dataloader 的 sampler 是否为 Unrepeated 和 Reproducible 来决定如何重新实例化 dataloader | |||||
""" | |||||
driver = PaddleFleetDriver( | |||||
model=paddle_model, | |||||
parallel_device=gpus, | |||||
@magic_argv_env_context | |||||
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler | |||||
的表现 | |||||
""" | |||||
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) | |||||
batch_sampler.sampler = RandomSampler(self.dataset, True) | |||||
dataloader = DataLoader( | |||||
self.dataset, | |||||
batch_sampler=batch_sampler | |||||
) | ) | ||||
driver.set_optimizers(paddle_opt) | |||||
dataloader = driver.set_dist_repro_dataloader(dataloader, ) | |||||
driver.setup() | |||||
# 检查model_device | |||||
self.assertEqual(driver.model_device, f"gpu:{os.environ['PADDLE_LOCAL_DEVICE_IDS']}") | |||||
driver.barrier() | |||||
driver.zero_grad() | |||||
current_epoch_idx = 0 | |||||
while current_epoch_idx < epochs: | |||||
epoch_loss, batch = 0, 0 | |||||
driver.set_model_mode("train") | |||||
driver.set_sampler_epoch(dataloader, current_epoch_idx) | |||||
for batch, (img, label) in enumerate(dataloader): | |||||
img = paddle.to_tensor(img) | |||||
out = driver.train_step(img) | |||||
label + 1 | |||||
loss = loss_func(out, label) | |||||
epoch_loss += loss.item() | |||||
if batch % 50 == 0: | |||||
print("epoch:{}, batch:{}, loss: {}, rank:{}".format(current_epoch_idx, batch, loss.item(), driver.local_rank)) | |||||
driver.backward(loss) | |||||
driver.step() | |||||
driver.zero_grad() | |||||
driver.barrier() | |||||
current_epoch_idx += 1 | |||||
# test | |||||
correct = 0 | |||||
driver.set_model_mode("eval") | |||||
for img, label in test_dataset: | |||||
img = paddle.to_tensor(np.array(img).astype('float32').reshape(1, -1)) | |||||
out = driver.test_step(img) | |||||
res = paddle.nn.functional.softmax(out).argmax().item() | |||||
label = label.item() | |||||
if res == label: | |||||
correct += 1 | |||||
print("{} / {}, acc: {}".format(correct, len(test_dataset), correct / len(test_dataset))) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 2 | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == True | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 UnrepeatedSampler | |||||
的表现 | |||||
""" | |||||
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) | |||||
batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, True) | |||||
dataloader = DataLoader( | |||||
self.dataset, | |||||
batch_sampler=batch_sampler | |||||
) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) | |||||
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 2 | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现 | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedSequentialSampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
dist.barrier() | |||||
def check_distributed_sampler(self, sampler): | |||||
""" | |||||
测试替换得到的 sampler 或 batch_sampler 的分布式设置是否正确 | |||||
""" | |||||
assert sampler.num_replicas == dist.get_world_size() | |||||
assert sampler.rank == dist.get_rank() | |||||
if not isinstance(sampler, UnrepeatedSampler): | |||||
assert sampler.pad == True | |||||
@@ -224,7 +224,6 @@ class TestSetDistReproDataloder: | |||||
""" | """ | ||||
def setup_method(self): | def setup_method(self): | ||||
self.dataset = PaddleNormalDataset(20) | self.dataset = PaddleNormalDataset(20) | ||||
self.dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) | |||||
model = PaddleNormalModel_Classification_1(10, 32) | model = PaddleNormalModel_Classification_1(10, 32) | ||||
self.driver = PaddleSingleDriver(model, device="cpu") | self.driver = PaddleSingleDriver(model, device="cpu") | ||||
@@ -233,55 +232,59 @@ class TestSetDistReproDataloder: | |||||
测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现 | 测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现 | ||||
当dist为字符串时,此时应该返回原来的 dataloader | 当dist为字符串时,此时应该返回原来的 dataloader | ||||
""" | """ | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=False) | |||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||||
assert replaced_loader is self.dataloader | |||||
assert replaced_loader is dataloader | |||||
def test_set_dist_repro_dataloader_with_reproducible_true(self): | def test_set_dist_repro_dataloader_with_reproducible_true(self): | ||||
""" | """ | ||||
测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 | 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 | ||||
当dist为字符串时,此时应该返回新的 dataloader,且 batch_sampler 为 RandomBatchSampler | 当dist为字符串时,此时应该返回新的 dataloader,且 batch_sampler 为 RandomBatchSampler | ||||
""" | """ | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=True) | |||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | |||||
assert not (replaced_loader is self.dataloader) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | ||||
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | ||||
assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size | |||||
assert replaced_loader.drop_last == self.dataloader.drop_last | |||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
# self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) | |||||
# self.check_set_dist_repro_dataloader(dataloader, replaced_loader) | |||||
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self): | def test_set_dist_repro_dataloader_with_dist_batch_sampler(self): | ||||
""" | """ | ||||
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler | 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler | ||||
应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler | 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler | ||||
""" | """ | ||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) | |||||
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False) | dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | |||||
assert not (replaced_loader is self.dataloader) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | ||||
assert replaced_loader.batch_sampler is dist | assert replaced_loader.batch_sampler is dist | ||||
# self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader) | |||||
def test_set_dist_repro_dataloader_with_dist_sampler(self): | def test_set_dist_repro_dataloader_with_dist_sampler(self): | ||||
""" | """ | ||||
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现 | 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现 | ||||
应该返回新的 dataloader,并将 batch_sampler.sampler 替换为 dist 对应的 Sampler | 应该返回新的 dataloader,并将 batch_sampler.sampler 替换为 dist 对应的 Sampler | ||||
""" | """ | ||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) | |||||
dist = RandomSampler(self.dataset, shuffle=True) | dist = RandomSampler(self.dataset, shuffle=True) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | |||||
assert not (replaced_loader is self.dataloader) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | assert isinstance(replaced_loader.batch_sampler, BatchSampler) | ||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | ||||
assert not (replaced_loader.batch_sampler is self.dataloader.batch_sampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert replaced_loader.batch_sampler.sampler is dist | assert replaced_loader.batch_sampler.sampler is dist | ||||
assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size | |||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||||
# self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader) | |||||
def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self): | def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self): | ||||
""" | """ | ||||
@@ -295,11 +298,12 @@ class TestSetDistReproDataloder: | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | ||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | ||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | ||||
assert replaced_loader.drop_last == dataloader.drop_last | assert replaced_loader.drop_last == dataloader.drop_last | ||||
# self.check_set_dist_repro_dataloader(dataloader, replaced_loader) | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader) | |||||
def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self): | def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self): | ||||
""" | """ | ||||
@@ -316,34 +320,52 @@ class TestSetDistReproDataloder: | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | ||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | ||||
assert replaced_loader.batch_sampler.batch_size == 2 | assert replaced_loader.batch_sampler.batch_size == 2 | ||||
assert replaced_loader.batch_sampler.sampler.shuffle == True | |||||
# self.check_set_dist_repro_dataloader(dataloader, replaced_loader) | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader) | |||||
def check_set_dist_repro_dataloader(self, dataloader, replaced_loader): | def check_set_dist_repro_dataloader(self, dataloader, replaced_loader): | ||||
""" | """ | ||||
测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确 | 测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确 | ||||
""" | """ | ||||
# 迭代两个 batch | # 迭代两个 batch | ||||
# 这里会发生 BatchSampler 里 yield 了多次但 dataloader 只取出一次的情况。 | |||||
num_consumed_batches = 2 | |||||
already_seen_idx = set() | already_seen_idx = set() | ||||
for idx, batch in replaced_loader: | |||||
already_seen_idx.update(batch) | |||||
if idx >= 1: | |||||
for idx, batch in enumerate(replaced_loader): | |||||
if idx >= num_consumed_batches: | |||||
break | break | ||||
already_seen_idx.update(batch) | |||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | ||||
sampler_states = replaced_loader.batch_sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.state_dict() | ||||
else: | else: | ||||
sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | ||||
print(sampler_states["data_idx"]) | |||||
# 加载 num_consumed_samples_array,设置正确取出的 batch 数目 | |||||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||||
import time | |||||
time.sleep(5) | |||||
# 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range | # 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range | ||||
left_idxes = set() | left_idxes = set() | ||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | ||||
batch_size = replaced_loader.batch_sampler.batch_size | |||||
if num_consumed_samples_array is not None: | |||||
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] | |||||
else: | |||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | |||||
replaced_loader.batch_sampler.load_state_dict(sampler_states) | replaced_loader.batch_sampler.load_state_dict(sampler_states) | ||||
else: | else: | ||||
batch_size = replaced_loader.batch_sampler.batch_size | |||||
if num_consumed_samples_array is not None: | |||||
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] | |||||
else: | |||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | |||||
replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states) | replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states) | ||||
replaced_loader.batch_sampler.sampler.set_epoch(0) | |||||
for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
left_idxes.update(batch) | left_idxes.update(batch) | ||||
@@ -401,12 +423,12 @@ class TestPaddleDriverFunctions: | |||||
测试is_train参数为True时,_check_dataloader_legality函数的表现 | 测试is_train参数为True时,_check_dataloader_legality函数的表现 | ||||
""" | """ | ||||
dataloader = paddle.io.DataLoader(PaddleNormalDataset()) | dataloader = paddle.io.DataLoader(PaddleNormalDataset()) | ||||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | |||||
# batch_size 和 batch_sampler 均为 None 的情形 | # batch_size 和 batch_sampler 均为 None 的情形 | ||||
dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) | dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | |||||
# 创建torch的dataloader | # 创建torch的dataloader | ||||
dataloader = torch.utils.data.DataLoader( | dataloader = torch.utils.data.DataLoader( | ||||
@@ -414,7 +436,7 @@ class TestPaddleDriverFunctions: | |||||
batch_size=32, shuffle=True | batch_size=32, shuffle=True | ||||
) | ) | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | |||||
def test_check_dataloader_legality_in_test(self): | def test_check_dataloader_legality_in_test(self): | ||||
""" | """ | ||||
@@ -425,7 +447,7 @@ class TestPaddleDriverFunctions: | |||||
"train": paddle.io.DataLoader(PaddleNormalDataset()), | "train": paddle.io.DataLoader(PaddleNormalDataset()), | ||||
"test":paddle.io.DataLoader(PaddleNormalDataset()) | "test":paddle.io.DataLoader(PaddleNormalDataset()) | ||||
} | } | ||||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||||
# batch_size 和 batch_sampler 均为 None 的情形 | # batch_size 和 batch_sampler 均为 None 的情形 | ||||
dataloader = { | dataloader = { | ||||
@@ -433,12 +455,12 @@ class TestPaddleDriverFunctions: | |||||
"test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) | "test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) | ||||
} | } | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||||
# 传入的不是dict,应该报错 | # 传入的不是dict,应该报错 | ||||
dataloader = paddle.io.DataLoader(PaddleNormalDataset()) | dataloader = paddle.io.DataLoader(PaddleNormalDataset()) | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||||
# 创建torch的dataloader | # 创建torch的dataloader | ||||
train_loader = torch.utils.data.DataLoader( | train_loader = torch.utils.data.DataLoader( | ||||
@@ -451,7 +473,7 @@ class TestPaddleDriverFunctions: | |||||
) | ) | ||||
dataloader = {"train": train_loader, "test": test_loader} | dataloader = {"train": train_loader, "test": test_loader} | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||||
def test_tensor_to_numeric(self): | def test_tensor_to_numeric(self): | ||||
""" | """ | ||||
@@ -181,6 +181,7 @@ class TestCheckNumberOfParameters: | |||||
def test_get_fun_msg(): | def test_get_fun_msg(): | ||||
# 测试运行 | |||||
def demo(x): | def demo(x): | ||||
pass | pass | ||||
@@ -1,3 +1,6 @@ | |||||
import numpy as np | |||||
class NormalIterator: | class NormalIterator: | ||||
def __init__(self, num_of_data=1000): | def __init__(self, num_of_data=1000): | ||||
self._num_of_data = num_of_data | self._num_of_data = num_of_data | ||||
@@ -15,4 +18,15 @@ class NormalIterator: | |||||
return self._data | return self._data | ||||
def __len__(self): | def __len__(self): | ||||
return self._num_of_data | |||||
return self._num_of_data | |||||
class RandomDataset: | |||||
def __init__(self, num_data=10): | |||||
self.data = np.random.rand(num_data) | |||||
def __len__(self): | |||||
return len(self.data) | |||||
def __getitem__(self, item): | |||||
return self.data[item] |
@@ -28,7 +28,7 @@ class TorchNormalModel_Classification_1(nn.Module): | |||||
x = self(x) | x = self(x) | ||||
return {"loss": self.loss_fn(x, y)} | return {"loss": self.loss_fn(x, y)} | ||||
def validate_step(self, x, y): | |||||
def evaluate_step(self, x, y): | |||||
""" | """ | ||||
如果不加参数 y,那么应该在 trainer 中设置 output_mapping = {"y": "target"}; | 如果不加参数 y,那么应该在 trainer 中设置 output_mapping = {"y": "target"}; | ||||
""" | """ | ||||