@@ -10,6 +10,7 @@ from .utils import _get_monitor_value
from fastNLP.core.callbacks.callback_events import _SingleEventState
from fastNLP.core.log import logger
from fastNLP.core.utils import apply_to_collection
from fastNLP.core.utils.utils import _check_valid_parameters_number


class Callback:
@@ -32,100 +33,225 @@ class Callback:
    def on_sanity_check_end(self, trainer, sanity_check_res):
        r"""
        Triggered after the sanity-check (trial run) finishes;

        :param trainer:
        :param sanity_check_res: the evaluate result of the sanity-check run
        :return:
        """
        pass
    def on_train_begin(self, trainer):
        r"""
        Triggered before training starts;

        :param trainer:
        :return:
        """
        pass

    def on_train_end(self, trainer):
        r"""
        Triggered after training finishes;

        :param trainer:
        :return:
        """
        pass

    def on_train_epoch_begin(self, trainer):
        r"""
        Triggered before each epoch of training starts;

        :param trainer:
        :return:
        """
        pass

    def on_train_epoch_end(self, trainer):
        r"""
        Triggered after each epoch of training finishes; at this point trainer.cur_epoch_idx has already been
        incremented by 1.

        :param trainer:
        :return:
        """
        pass
    def on_fetch_data_begin(self, trainer):
        r"""
        Triggered when the training loop is about to fetch the next batch of data;

        :param trainer:
        :return:
        """
        pass

    def on_fetch_data_end(self, trainer):
        r"""
        Triggered once the current batch of data has been fetched during training;

        :param trainer:
        :return:
        """
        pass
    def on_train_batch_begin(self, trainer, batch, indices):
        r"""
        Triggered after the data has been fetched, input_mapping has been applied (if that argument was passed
        to the Trainer), and the tensors in the batch have been moved to the target device. The batch is thus
        either in the format returned by the DataLoader or in the format produced by input_mapping.
        If the batch is a dict, adding or removing keys, or modifying values in place, will directly affect the
        batch that is fed into the model.

        :param trainer: `fastNLP.Trainer`
        :param batch: the batch data, already passed through input_mapping (if any) and moved to the target device.
        :param list[int] indices: which samples of the dataset make up the current batch
        """
        pass
    def on_train_batch_end(self, trainer):
        """
        Triggered after one batch has gone through training (forward), gradient back-propagation (backward),
        the optimizer step, and gradient zeroing, and after batch_idx_in_epoch and global_forward_batches have
        been incremented by 1. Note that the optimizer step and gradient zeroing honor accumulation_steps and
        therefore do not necessarily run on the current batch.

        :param trainer:
        :return:
        """
        pass
    def on_exception(self, trainer, exception):
        """
        Called when an exception is raised during training.

        :param trainer:
        :param exception: the exception that was encountered.
        :return:
        """
        pass

    def on_save_model(self, trainer):
        """
        Called right before the model is saved; at this point the model has not been saved yet.

        :param trainer:
        :return:
        """
        pass

    def on_load_model(self, trainer):
        """
        Called right before the model is loaded; at this point the model has not been loaded yet.

        :param trainer:
        :return:
        """
        pass
    def on_save_checkpoint(self, trainer) -> Dict:
        """
        Triggered when the Trainer is about to save a checkpoint; this function should save whatever data this
        callback needs in order to be restored later.

        :param trainer:
        :return:
        """
        pass

    def on_load_checkpoint(self, trainer, states: Optional[Dict]):
        r"""
        Triggered when the Trainer restores a checkpoint (the Trainer and Driver have already loaded their own
        states); the states argument is the return value of on_save_checkpoint().

        :param trainer:
        :param states:
        :return:
        """
        pass
    def on_before_backward(self, trainer, outputs):
        """
        Executed before backward.

        :param trainer:
        :param outputs: the return value of the model. If output_mapping was given, outputs holds the result of
            applying output_mapping.
        :return:
        """
        pass

    def on_after_backward(self, trainer):
        """
        Executed after backward. In the multi-GPU case, because of accumulation_steps, gradient synchronization
        only happens on the backward pass that actually precedes a parameter update; so with multiple GPUs and
        accumulation_steps, gradients may differ across devices on some steps.

        :param trainer:
        :return:
        """
        pass
    def on_before_optimizers_step(self, trainer, optimizers):
        """
        Called before the optimizer step. This hook is not necessarily triggered on every forward pass; whether
        it actually runs depends on accumulation_steps.

        :param trainer:
        :param optimizers: the optimizers, i.e. the values passed in when the Trainer was initialized.
        :return:
        """
        pass

    def on_after_optimizers_step(self, trainer, optimizers):
        """
        Called after the optimizer step. This hook is not necessarily triggered on every forward pass; whether
        it actually runs depends on accumulation_steps.

        :param trainer:
        :param optimizers: the optimizers, i.e. the values passed in when the Trainer was initialized.
        :return:
        """
        pass

    def on_before_zero_grad(self, trainer, optimizers):
        """
        Called before the model gradients are zeroed. This hook is not necessarily triggered on every forward
        pass; whether it actually runs depends on accumulation_steps.

        :param trainer:
        :param optimizers: the optimizers, i.e. the values passed in when the Trainer was initialized.
        :return:
        """
        pass

    def on_after_zero_grad(self, trainer, optimizers):
        """
        Called after the model gradients are zeroed. This hook is not necessarily triggered on every forward
        pass; whether it actually runs depends on accumulation_steps.

        :param trainer:
        :param optimizers: the optimizers, i.e. the values passed in when the Trainer was initialized.
        :return:
        """
        pass
    def on_validate_begin(self, trainer):
        """
        Called when validation is about to run. If the validation frequency is controlled by a step count or a
        custom function, this hook is called after on_train_batch_end; if it is controlled by an epoch count, it
        is called after on_train_epoch_end.

        :param trainer:
        :return:
        """
        pass

    def on_validate_end(self, trainer, results):
        """
        Called when validation finishes; the validation results are passed in.

        :param trainer:
        :param results:
        :return:
        """
        pass
    @property
    def callback_name(self):
        """
        The name of this callback. It is used to look up the corresponding state read from a checkpoint, which is
        then passed to the on_load_checkpoint() function.

        :return:
        """
        return self.__class__.__name__
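As an illustration of the hook API above, here is a minimal sketch of a custom callback (the class name and
the logging interval are our own, not part of fastNLP):

class LossLoggerCallback(Callback):
    def __init__(self, log_every: int = 100):
        self.log_every = log_every

    def on_train_batch_end(self, trainer):
        # global_forward_batches already counts the batch that just finished.
        if trainer.global_forward_batches % self.log_every == 0:
            logger.info(f"finished batch {trainer.global_forward_batches} "
                        f"(epoch {trainer.cur_epoch_idx})")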
@@ -174,7 +300,11 @@ class HasMonitorCallback(Callback):
        self.must_have_monitor = must_have_monitor

    def set_monitor(self, monitor, larger_better):
        if callable(monitor):  # check that it accepts exactly one argument
            _check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor')
            self.monitor = monitor
        else:
            self.monitor = str(monitor) if monitor is not None else None
        self.larger_better = bool(larger_better)
        if larger_better:
            self.monitor_value = float('-inf')
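A callable monitor is expected to take the evaluation results dict and return a float, e.g. (a hedged
sketch; the metric keys are made up):

def my_monitor(results: dict) -> float:
    # combine two reported metrics into a single monitored value
    return 0.5 * results['acc#acc'] + 0.5 * results['f1#f1']

# callback.set_monitor(my_monitor, larger_better=True)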
@@ -197,24 +327,33 @@ class HasMonitorCallback(Callback):
            raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
                               f"You can set it in the initialization or through Trainer.")

    def get_monitor_value(self, results: Dict) -> Union[float, None]:
        """
        Fetch the value of the monitor. If the monitor key is not found directly, fuzzy matching is attempted,
        and the matched key is stored on the self._real_monitor attribute.

        :param results:
        :return: None means no suitable monitor was found this time
        """
        if len(results) == 0:
            return None
        # make sure every tensor has been converted to a plain python type
        results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
        use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
                                                        real_monitor=self._real_monitor,
                                                        res=results)
        if monitor_value is None:
            return monitor_value
        # first run
        if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
            logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
                           f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.")
        # the matched monitor differs from the one used last time
        elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
            logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
                           f"The expected monitor is:`{self.monitor}`, last used monitor is:"
                           f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
                           f"customized monitor function when the evaluation results are varying between validation.")
        self._real_monitor = use_monitor
        return monitor_value
@@ -222,14 +361,33 @@ class HasMonitorCallback(Callback):
        """
        Check whether monitor_value is an improvement.

        :param monitor_value: the monitor_value to check. If it is None, return False.
        :param keep_if_better: if the given monitor_value is better, store it.
        :return:
        """
        if monitor_value is None:
            return False
        better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
        if keep_if_better and better:
            self.monitor_value = monitor_value
        return better

    def is_former_monitor_value_better(self, monitor_value1, monitor_value2):
        """
        Of the two given values, is monitor_value1 the better one?

        :param monitor_value1:
        :param monitor_value2:
        :return:
        """
        if monitor_value1 is None and monitor_value2 is None:
            return True
        if monitor_value1 is None:
            return False
        if monitor_value2 is None:
            return True
        better = False
        if (self.larger_better and monitor_value1 > monitor_value2) or \
                (not self.larger_better and monitor_value1 < monitor_value2):
            better = True
        return better
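Usage sketch of the two helpers above (assuming larger_better=True and an already-initialized callback `cb`):

# cb.is_former_monitor_value_better(0.9, 0.8)            -> True, 0.9 is the better value
# cb.is_better_monitor_value(0.9, keep_if_better=True)   -> compares against cb.monitor_value, stores 0.9 if better
# cb.is_better_monitor_value(None)                       -> False, None is never an improvement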
@@ -74,28 +74,30 @@ class EventEnum(_SingleEventState, Enum):
@unique
class Events(EventEnum):
    on_after_trainer_initialized = "on_after_trainer_initialized"
    on_sanity_check_begin = "on_sanity_check_begin"
    on_sanity_check_end = "on_sanity_check_end"
    on_train_begin = "on_train_begin"
    on_train_end = "on_train_end"
    on_train_epoch_begin = "on_train_epoch_begin"
    on_train_epoch_end = "on_train_epoch_end"
    on_fetch_data_begin = "on_fetch_data_begin"
    on_fetch_data_end = "on_fetch_data_end"
    on_train_batch_begin = "on_train_batch_begin"
    on_train_batch_end = "on_train_batch_end"
    on_exception = "on_exception"
    on_save_model = "on_save_model"
    on_load_model = "on_load_model"
    on_save_checkpoint = "on_save_checkpoint"
    on_load_checkpoint = "on_load_checkpoint"
    on_before_backward = "on_before_backward"
    on_after_backward = "on_after_backward"
    on_before_optimizers_step = "on_before_optimizers_step"
    on_after_optimizers_step = "on_after_optimizers_step"
    on_before_zero_grad = "on_before_zero_grad"
    on_after_zero_grad = "on_after_zero_grad"
    on_validate_begin = "on_validate_begin"
    on_validate_end = "on_validate_end"
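With the lowercase member names, an event is referenced the same way the hook is named, e.g. (a sketch;
the marker value is arbitrary):

@Trainer.on(Events.on_train_epoch_end, marker='my_trainer')
def print_epoch(trainer):
    print(f"epoch {trainer.cur_epoch_idx} done")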

class EventsList:
@@ -169,20 +171,8 @@ class Filter:
        self.num_called += 1
        # The inputs of our callback functions are fixed, and we can guarantee that the first argument is always
        # trainer; so we can pull trainer out of the callback's inputs and hand it to the filter, which makes
        # more complex filtering logic possible;
        trainer = args[0]
        if self._filter(self, trainer):
            self.num_executed += 1
            return fn(*args, **kwargs)
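Sketch of the bookkeeping above: num_called counts every attempt, num_executed only the calls that pass the
filter (the `every` constructor argument is assumed from its use elsewhere in this diff):

every_second = Filter(every=2)
wrapped = every_second(hook)   # hook(trainer) is some callback function
# each wrapped(trainer) call increments num_called; num_executed grows only on calls 2, 4, 6, ...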
@@ -278,13 +278,21 @@ class CallbackManager:
        pass

    @_transfer
    def on_before_optimizers_step(self, trainer, optimizers):
        pass

    @_transfer
    def on_after_optimizers_step(self, trainer, optimizers):
        pass

    @_transfer
    def on_before_zero_grad(self, trainer, optimizers):
        pass

    @_transfer
    def on_after_zero_grad(self, trainer, optimizers):
        pass

    @_transfer
    def on_validate_begin(self, trainer):
        pass
@@ -10,12 +10,10 @@ from copy import deepcopy
import fastNLP
from .callback import HasMonitorCallback
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_LAUNCH_TIME
from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir


class CheckpointCallback(HasMonitorCallback):
@@ -167,6 +165,8 @@ class CheckpointCallback(HasMonitorCallback):
        """
        if self.save_topk is not None:
            monitor_value = self.get_monitor_value(results=results)
            if monitor_value is None:
                return
            folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \
                          f"-{self._real_monitor}_{monitor_value}"
@@ -178,8 +178,7 @@ class CheckpointCallback(HasMonitorCallback):
            else:
                _least_valuable_model = (min if self.larger_better else max)(self._topk_model,
                                                                             key=lambda x: self._topk_model[x])
                if self.is_former_monitor_value_better(monitor_value, self._topk_model[_least_valuable_model]):
                    self._topk_model[folder_name] = monitor_value
                    _should_save = True
                    self._topk_model.pop(_least_valuable_model)
@@ -208,21 +207,6 @@ class CheckpointCallback(HasMonitorCallback):
            **self.kwargs
        )

    @property
    def folder_prefix(self):
        raise NotImplementedError("The `folder_prefix` is not specified")
@@ -248,7 +232,8 @@ class ModelCheckpointCallback(CheckpointCallback):
        If model_save_fn is not None, fastNLP passes the absolute path of folder to that function and creates no
        file under the folder itself.

    :param monitor: the name of the metric to monitor. If no exactly matching name is found in the evaluation
        results, the longest-common-substring match is used as the monitor. If None, the value is taken from the
        Trainer if available. A function may also be passed; it receives the evaluation results (a dict) and
        returns a float to use as the monitor value.
    :param save_folder: the folder to save into; fastNLP creates a timestamped subfolder under it and saves there,
        so different runs end up in different timestamp folders. If None, the current folder is used.
    :param save_every_n_epochs: save once every this many epochs.
@@ -295,7 +280,8 @@ class TrainerCheckpointCallback(CheckpointCallback):
        If model_save_fn is not None, fastNLP only creates the fastnlp_trainer.pkl.tar file in each folder.

    :param monitor: the name of the metric to monitor. If no exactly matching name is found in the evaluation
        results, the longest-common-substring match is used as the monitor. If None, the value is taken from the
        Trainer if available. A function may also be passed; it receives the evaluation results (a dict) and
        returns a float to use as the monitor value.
    :param save_folder: the folder to save into; fastNLP creates a timestamped subfolder under it and saves there,
        so different runs end up in different timestamp folders. If None, the current folder is used.
    :param save_every_n_epochs: save once every this many epochs.
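A sketch of passing a callable monitor to the checkpoint callbacks above (the metric key and the save_topk
parameter are assumptions based on the surrounding code):

callback = ModelCheckpointCallback(
    monitor=lambda results: results.get('acc#acc', 0.0),
    save_folder='checkpoints',
    save_topk=3,
)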
@@ -2,17 +2,18 @@ __all__ = [
    'EarlyStopCallback'
]

from typing import Dict, Union, Callable

from .callback import HasMonitorCallback
from fastNLP.core.utils.exceptions import EarlyStopException


class EarlyStopCallback(HasMonitorCallback):
    def __init__(self, monitor: Union[str, Callable] = None, larger_better: bool = True, patience: int = 10):
        """
        :param str monitor: the metric to monitor. If None, the monitor set on the Trainer is used if available.
            A function may also be passed; it receives the evaluation results (a dict) and returns a float to use
            as the monitor value.
        :param larger_better: whether a larger monitor value is better.
        :param patience: stop after this many validations without improvement.
        """
@@ -21,9 +22,9 @@ class EarlyStopCallback(HasMonitorCallback):
        self.patience = patience

    def on_validate_end(self, trainer, results):
        monitor_value = self.get_monitor_value(results)
        if monitor_value is None:
            return
        if self.is_better_monitor_value(monitor_value, keep_if_better=True):
            self.wait = 0
        else:
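Sketch of the patience behavior above: wait resets to 0 on every improvement and, judging from the imported
EarlyStopException, training is cut short once wait exceeds patience ('loss#loss' is a made-up key):

callback = EarlyStopCallback(monitor='loss#loss', larger_better=False, patience=5)
# training stops after 5 consecutive validations without a lower loss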
@@ -3,7 +3,7 @@ __all__ = [
]

import os
from typing import Optional, Callable, Union
from .callback import HasMonitorCallback
from io import BytesIO
import shutil
@@ -14,14 +14,15 @@ from fastNLP.envs import all_rank_call


class LoadBestModelCallback(HasMonitorCallback):
    def __init__(self, monitor: Union[str, Callable] = None, larger_better: bool = True, only_state_dict: bool = True,
                 save_folder: Optional[str] = None, model_save_fn: Optional[Callable] = None,
                 model_load_fn: Optional[Callable] = None,
                 delete_after_train: bool = True):
        """
        Saves the model with the best monitor value and reloads it when training ends. The best model can only be
        loaded if training finishes normally.

        :param str monitor: the metric to monitor. If None, the monitor set on the Trainer is used if available.
            A function may also be passed; it receives the evaluation results (a dict) and returns a float to use
            as the monitor value.
        :param larger_better: whether a larger value of this metric is better.
        :param save_folder: the folder to save to; if None, the weights are kept in memory. If not None, a copy of
            the weights is written to a file; for multi-machine training with a non-None value, make sure the path
            is accessible from every machine. This value must not be None when model_save_fn is not None.
@@ -78,9 +79,9 @@ class LoadBestModelCallback(HasMonitorCallback):
        self.get_monitor_value(sanity_check_res)

    def on_validate_end(self, trainer, results):
        monitor_value = self.get_monitor_value(results)
        if monitor_value is None:
            return
        if self.is_better_monitor_value(monitor_value, keep_if_better=True):
            if self.real_save_folder:
                trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
@@ -45,6 +45,7 @@ class RichCallback(ProgressCallback):
    :param print_every: update the display every this many batches.
    :param loss_round_ndigit: how many significant digits of the loss to display
    :param monitor: when a better result for this key is detected, it is printed in a different color as a hint.
        If None, the monitor set on the trainer is used if available. A function may also be passed; it receives
        the evaluation results (a dict) and returns a float to use as the monitor value.
    :param larger_better: whether a larger monitor result is better.
    :param format_json: whether to format the json before printing
    """
@@ -135,7 +136,8 @@ class RawTextCallback(ProgressCallback):
    :param print_every: update the display every this many batches.
    :param loss_round_ndigit: how many significant digits of the loss to display
    :param monitor: when a better result for this key is detected, it is printed in a different color as a hint.
        A function may also be passed; it receives the evaluation results (a dict) and returns a float to use as
        the monitor value.
    :param larger_better: whether a larger monitor result is better.
    :param format_json: whether to format the json before printing
    """
@@ -1,9 +1,10 @@
from typing import Optional, Union

from fastNLP.core.log.logger import logger
from difflib import SequenceMatcher
from fastNLP.core.utils.utils import _get_fun_msg


def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str], res: dict) -> (str, float):
    """
    Look up monitor in res and return it. If monitor is not found, _real_monitor is tried; if _real_monitor is
    None, fuzzy matching on the value of monitor is attempted.

@@ -11,10 +12,19 @@ def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->(
    :param monitor:
    :param real_monitor:
    :param res:
    :return: two values (str, value): str is the key that ends up being used, value is the value for that key.
        A value of None means no matching monitor was found in the current results.
    """
    if len(res) == 0:
        return monitor, None

    if callable(monitor):
        try:
            monitor_value = monitor(res)
        except BaseException as e:
            logger.error(f"Exception happens when calling customized monitor function:{_get_fun_msg(monitor)}.")
            raise e
        return monitor, monitor_value

    if monitor in res:
        return monitor, res[monitor]
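The fuzzy matching referred to above can be sketched with SequenceMatcher, which is imported at the top of
this file (this mirrors the idea, not necessarily the exact matching code):

def _longest_match_len(a: str, b: str) -> int:
    # length of the longest common substring of a and b
    m = SequenceMatcher(None, a, b).find_longest_match(0, len(a), 0, len(b))
    return m.size

# among the keys of res, the one sharing the longest common substring with `monitor` would win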
@@ -5,7 +5,7 @@ __all__ = [
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, List, Callable, Union, Tuple
from numbers import Number
import warnings

@@ -35,7 +35,7 @@ class SetInputOrTargetException(Exception):
        self.field_name = field_name  # the name of the current field


def _get_ele_type_and_dim(cell: Any, dim=0) -> Tuple[Any, int]:
    r"""
    Identify the type and the number of dimensions of cell

@@ -197,7 +197,7 @@ class _MultiCollator:
            collator.set_input(*field_names)
            flag = False
        if flag:
            warnings.warn("AutoCollator is removed, set_input is unavailable!!")
        return self
@@ -206,7 +206,7 @@ class AutoCollator(Collator):
    def __init__(self, as_numpy: bool):
        super(AutoCollator, self).__init__()
        self.pad_field_value = {}  # user-defined padding value per field, default 0
        self.need_inputs = set()  # the field names that are needed
        self.field_dtypes = None  # the dtype of each column's cells
        self.field_dims = None  # the dimensionality of each column's cells
        self.as_numpy = as_numpy

@@ -214,10 +214,17 @@ class AutoCollator(Collator):
    def __call__(self, ins_lst: List[Dict]) -> dict:
        if len(self.need_inputs) == 0:
            raise ValueError("`need_inputs` is empty, you should call the `set_input` method first!!")
        # TODO we should first check which fields need padding, then check whether they can be padded
        # case 1: the values were configured via set_input
        # case 2: decide whether to pad based on the data type
        if self.field_dtypes is None and self.field_dims is None:
            field_dtypes, field_dims = {}, {}
            for key, value in ins_lst[0].items():
                if key in self.need_inputs and self.pad_field_value.get(key, 0) is not None:
                    field_dtypes[key], field_dims[key] = _get_ele_type_and_dim(value)
            self.field_dtypes = field_dtypes
            self.field_dims = field_dims

        pack_ins_lst, pad_ins_lst = {field_name: []
                                     for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {}
@@ -233,13 +240,13 @@ class AutoCollator(Collator):
        if len(self.pad_field_value.keys()) > 0:
            # drop the columns that need no padding; set_input columns that do not exist are ignored
            non_pad_field_names = []
            for k, v in self.pad_field_value.items():
                if v is None:
                    non_pad_field_names.append(k)

            # drop_field_names = list(set(list(ins_lst[0].keys())) - set(drop_fields))
            for field_name in non_pad_field_names:
                field_array = pack_ins_lst.pop(field_name)
                pad_ins_lst[field_name] = np.array(field_array)

@@ -269,7 +276,7 @@ class AutoCollator(Collator):
    def set_input(self, *field_names):
        for field_name in field_names:
            self.need_inputs.add(field_name)
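Usage sketch of the collator above (field names are made up):

collator = AutoCollator(as_numpy=False)
collator.set_input('words', 'target')
batch = collator([{'words': [3, 4], 'target': 0},
                  {'words': [5, 6, 7], 'target': 1}])  # 'words' gets padded to equal length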

def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool):
@@ -11,11 +11,12 @@ __all__ = [
from fastNLP.core.drivers import Driver
from fastNLP.core.drivers.utils import choose_driver
from .loops import Loop, EvaluateBatchLoop
from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \
    match_and_substitute_params, f_rich_progress
from fastNLP.core.metrics import Metric
from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader
from fastNLP.core.utils.utils import _check_valid_parameters_number
from fastNLP.core.log import logger
@@ -38,10 +39,11 @@ class Evaluator:
                 driver: Union[str, Driver] = 'single',
                 device: Optional[Union[int, List[int], str]] = None,
                 batch_step_fn: Optional[callable] = None,
                 mode: Optional[Union[str, callable]] = 'validate',  # try evaluate_step first, then forward; may be a callable
                 input_mapping: Optional[Union[Callable, Dict]] = None,
                 output_mapping: Optional[Union[Callable, Dict]] = None,
                 model_wo_auto_param_call: bool = False,
                 fp16: bool = False,
                 verbose: int = 1,
                 **kwargs
                 ):
@@ -61,6 +63,9 @@ class Evaluator:
            if that is absent, the "validate_step" function is tried; if neither is found, the model's forward
            computation is used.
        :param input_mapping: what the dataloader produces is processed by input_mapping before being fed to the
            model and the metrics
        :param output_mapping: what the model produces is processed by output_mapping before being fed to the
            metrics.
        :param model_wo_auto_param_call: whether to disable using our auto_param_call to automatically match the
            batch against the parameters of the forward function; if this value is False and the batch is a dict,
            we extract the objects the forward function needs from the batch and pass them into forward; if it is
            True, the batch is passed through to forward untouched. Note that the same logic applies to
            `train_step`, `validate_step` and `test_step`;
        :param fp16: whether to use fp16.
        :param verbose: whether to print the evaluate results.
        :param kwargs:
@@ -83,13 +88,13 @@ class Evaluator:
        self.model = model
        self.metrics = metrics

        self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, **kwargs)

        self.device = device
        self.verbose = verbose

        if batch_step_fn is not None:
            _check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn')
        self.batch_step_fn = batch_step_fn

        self.mode = mode
@@ -131,6 +136,7 @@ class Evaluator:
        if self.progress_bar == 'auto':
            self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw'

        self.driver.check_evaluator_mode(self.mode)
        self.driver.barrier()

    def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict:
@@ -150,8 +156,6 @@ class Evaluator:
        assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type."
        assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0."

        if self.mode == 'validate':
            assert self.driver.has_validate_dataloaders()
        else:
@@ -219,7 +223,6 @@ class Evaluator:
    def remove_progress_bar(self, dataloader_name):
        if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'):
            f_rich_progress.destroy_task(self._rich_task_id)
            delattr(self, '_rich_task_id')
        elif self.progress_bar == 'raw':
            desc = 'Evaluation ends'
@@ -230,7 +233,6 @@ class Evaluator:
    def finally_progress_bar(self):
        if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'):
            f_rich_progress.destroy_task(self._rich_task_id)
            delattr(self, '_rich_task_id')

    @property
@@ -355,20 +357,24 @@ class _MetricsWrapper:
        if is_dataclass(outputs):
            outputs = dataclass_to_dict(outputs)
        for metric in self._metrics:
            args = []
            if not isinstance(batch, dict):
                logger.warning_once(f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on "
                                    f"the output of model to update metric.")
            else:
                args.append(batch)
            if not isinstance(outputs, dict):
                raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly"
                                   f" return a dict from your model or use `output_mapping` to convert it into dict type.")
            if isinstance(metric, Metric):
                # passing signature_fn makes the error message clearer when auto_param_call fails.
                auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__)
            elif _is_torchmetrics_metric(metric):
                auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__)
            elif _is_allennlp_metric(metric):
                auto_param_call(metric.__call__, outputs, *args)
            elif _is_paddle_metric(metric):
                res = auto_param_call(metric.compute, outputs, *args)
                metric.update(res)

    def reset(self):
@@ -7,6 +7,7 @@ from typing import Optional, Callable

from .loop import Loop
from fastNLP.core.log import logger
from fastNLP.core.utils import match_and_substitute_params
from fastNLP.core.utils.exceptions import EarlyStopException


class TrainBatchLoop(Loop):
@@ -23,13 +24,15 @@ class TrainBatchLoop(Loop):
            try:
                trainer.on_fetch_data_begin()
                batch = next(dataloader)
                indices = get_batch_indices()
                trainer.on_fetch_data_end()
                batch = match_and_substitute_params(trainer.input_mapping, batch)
                batch = trainer.move_data_to_device(batch)
            except StopIteration:
                break
            except EarlyStopException:  # the Trainer handles the early-stop exception
                break
            except BaseException as e:
                if indices:
                    logger.debug(f"The following exception happens when running on samples: {indices}")
                raise e
@@ -14,6 +14,7 @@ __all__ = [
from .loops import Loop, TrainBatchLoop
from .utils import State, TrainerState
from .utils.utils import check_validate_every
from .evaluator import Evaluator
from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList, Filter
@@ -21,7 +22,8 @@ from fastNLP.core.callbacks.callback import _CallbackWrapper
from fastNLP.core.callbacks.callback_events import _SingleEventState
from fastNLP.core.drivers import Driver
from fastNLP.core.drivers.utils import choose_driver
from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext
from fastNLP.core.utils.utils import _check_valid_parameters_number
from fastNLP.envs import rank_zero_call
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_MODEL_FILENAME
@@ -42,15 +44,16 @@ class Trainer(TrainerEventTrigger):
                 validate_dataloaders=None,
                 batch_step_fn: Optional[Callable] = None,
                 validate_batch_step_fn: Optional[Callable] = None,
                 validate_mode: Union[str, callable] = 'validate',
                 callbacks: Union[List[Callback], Callback, None] = None,
                 metrics: Optional[dict] = None,
                 validate_every: Optional[Union[int, callable]] = -1,
                 input_mapping: Optional[Union[Callable, Dict]] = None,
                 output_mapping: Optional[Union[Callable, Dict]] = None,
                 model_wo_auto_param_call: bool = False,
                 accumulation_steps: int = 1,
                 fp16: bool = False,
                 monitor: Union[str, callable] = None,
                 larger_better: bool = True,
                 marker: Optional[str] = None,
                 **kwargs
                 ):
@@ -89,11 +92,8 @@ class Trainer(TrainerEventTrigger):
        :param callbacks: the callback classes triggered during training; this should be a list whose elements all
            inherit from the `Callback` class;
        :param metrics: should be a dict whose keys name the monitors, e.g. {"acc1": AccMetric(), "acc2": AccMetric()};
        :param validate_every: may be a negative integer, a positive integer, or a function; a negative value means
            validate every that many epochs; a positive value means validate every that many batches; a function
            gives the user control over the validation frequency: it receives the current trainer object as its
            argument and returns a bool, where True means validation should run; it is called after every batch
            (see the sketch after this docstring);
        :param input_mapping: should be a dict or a function describing how a fetched batch of training data is
            mapped at the current step; if it is a dict and the batch is also a `Dict`, the keys of the batch that
            also appear in input_mapping are renamed to the corresponding values of input_mapping; if the batch is
            a `dataclass`, it is first converted to a Dict and then transformed as above; if the batch is some other
@@ -102,12 +102,15 @@ class Trainer(TrainerEventTrigger):
        :param output_mapping: should be a dict or a function. It works like input_mapping but transforms the
            output; if output_mapping is a function, the model output is passed to it directly; if it is a `Dict`,
            the batch must be a `Dict` or a `dataclass`: if the batch is a `Dict`, the keys of the batch that also
            appear in output_mapping are renamed to the corresponding values of output_mapping; if the batch is a
            `dataclass`, it is first converted to a Dict and then transformed as above;
        :param model_wo_auto_param_call: whether to disable using our auto_param_call to automatically match the
            batch against the parameters of the forward function during training; if this value is False and the
            batch is a dict, we extract the objects the forward function needs from the batch and pass them into
            forward; if it is True, the batch is passed through to the model untouched. Note that this parameter
            applies to `train_step`, `validate_step` and `test_step`;
        :param accumulation_steps: the number of gradient-accumulation steps, i.e. the optimizer runs once every
            that many batches; defaults to 1;
        :param fp16: whether to enable mixed-precision training; defaults to False;
        :param monitor: the default monitor metric name when validate_dataloaders is present. Any callback passed
            in that has a monitor parameter not set at callback initialization will pick up this value. If no
            exactly matching name is found in the evaluation results, the longest-common-substring match is used
            as the monitor. A function may also be passed; it receives the evaluation results (a dict) and returns
            a float to use as the monitor value.
        :param larger_better: whether a larger monitor value is better.
        :param marker: marks a Trainer instance, so that when the user calls `Trainer.on`, the callback function
            can be tied to one specific 'trainer' instance; defaults to None;
        :param kwargs: other parameters that may be needed;
@@ -126,20 +129,21 @@ class Trainer(TrainerEventTrigger):
            auto means rich is used if the current terminal is detected as interactive, and raw otherwise.
        """
        self.model = model
        self.marker = marker
        if isinstance(driver, str):
            self.driver_name = driver
        else:
            self.driver_name = driver.__class__.__name__
        self.device = device
        self.optimizers = optimizers
        self.fp16 = fp16
        self.input_mapping = input_mapping
        self.output_mapping = output_mapping

        self.batch_step_fn = batch_step_fn
        if batch_step_fn is not None:
            _check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn')
            self.check_batch_step_fn = partial(self._check_callback_called_legality, check_mode=True)
        else:
            self.check_batch_step_fn = lambda *args, **kwargs: ...
@@ -155,6 +159,8 @@ class Trainer(TrainerEventTrigger):
        elif accumulation_steps < 0:
            raise ValueError("Parameter `accumulation_steps` can only be bigger than 0.")
        self.accumulation_steps = accumulation_steps

        # todo: the rough idea is that each driver reports its own parameters (matching its initializer), and the
        #  trainer/evaluator checks at initialization that the values it holds agree with the driver's, warning
        #  the user where they differ. This can probably wait.
        self.driver = choose_driver(
            model=model,
            driver=driver,
@@ -171,6 +177,7 @@ class Trainer(TrainerEventTrigger):
            validate_every=validate_every,
            input_mapping=input_mapping,
            output_mapping=output_mapping,
            model_wo_auto_param_call=model_wo_auto_param_call,
            accumulation_steps=accumulation_steps,
            fp16=fp16,
            marker=marker,
@@ -212,17 +219,11 @@ class Trainer(TrainerEventTrigger):
        if metrics is not None and validate_dataloaders is None:
            raise ValueError("You have set 'metrics' but forgot to set 'validate_dataloader'.")

        self.evaluator = None
        self.monitor = monitor
        self.larger_better = larger_better
        if metrics is not None and validate_dataloaders is not None:
            check_validate_every(validate_every)
            self.evaluator = Evaluator(
                model=model,
                dataloaders=validate_dataloaders,
@@ -239,16 +240,6 @@ class Trainer(TrainerEventTrigger):
                progress_bar=kwargs.get('progress_bar', 'auto')
            )

        self.metrics = metrics
        self.validate_every = validate_every
@@ -317,6 +308,8 @@ class Trainer(TrainerEventTrigger):
        try:
            while self.cur_epoch_idx < self.n_epochs:
                # this prevents saving again before the current epoch has finished after a Trainer.load
                self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch
                self.driver.set_model_mode("train")
                self.on_train_epoch_begin()
                self.driver.set_sampler_epoch(self.dataloader, self.cur_epoch_idx)
@@ -345,31 +338,37 @@ class Trainer(TrainerEventTrigger):
            raise e
    def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl):
        def _validate_fn(trainer: Trainer, validate_fn: Callable) -> None:
            trainer.on_validate_begin()
            _validate_res: dict = validate_fn()
            trainer.on_validate_end(_validate_res)

        self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl))

    def step_validate(self):
        """
        Called after each batch; runs evaluate according to the configuration.

        :return:
        """
        if self.evaluator is not None:
            if callable(self.validate_every):
                if self.validate_every(self):
                    self.run_evaluate()
            elif self.validate_every > 0 and self.global_forward_batches % self.validate_every == 0:
                self.run_evaluate()

    def epoch_validate(self):
        """
        Called after each epoch; runs evaluate according to the configuration.

        :return:
        """
        if self.evaluator is not None:
            if isinstance(self.validate_every, int) and self.validate_every < 0:
                validate_every = -self.validate_every
                if self.cur_epoch_idx % validate_every == 0:
                    self.run_evaluate()
    def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable):
        r"""
@@ -400,9 +399,8 @@ class Trainer(TrainerEventTrigger):
        def wrapper(fn: Callable) -> Callable:
            cls._custom_callbacks[marker].append((event, fn))
            callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:]
            _check_valid_parameters_number(fn, callback_fn_args)
            return fn

        return wrapper
@@ -431,9 +429,11 @@ class Trainer(TrainerEventTrigger):
        2. purpose of this function
            This function checks whether a user-customized batch_step_fn / TrainBatchLoop still calls the callback
            functions correctly; more precisely, when the user has customized
            ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
            "on_before_zero_grad", "on_after_zero_grad") /
            ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end",
            "on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
            "on_before_zero_grad", "on_after_zero_grad")
            callback_fns and also customized batch_step_fn / TrainBatchLoop, they may have forgotten to call these
            callback functions in their own batch_step_fn, and this function detects that situation;
@@ -443,10 +443,12 @@ class Trainer(TrainerEventTrigger):
            'batch_step_fn'; False means the 'TrainBatchLoop' is checked;
        """
        if check_mode:
            callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
                         "on_before_zero_grad", "on_after_zero_grad")
        else:
            callbacks = ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end",
                         "on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
                         "on_before_zero_grad", "on_after_zero_grad")
        _not_called_callback_fns = []
        for each_callback_fn in callbacks:
            if each_callback_fn in self.callback_manager.callback_fns:
@@ -498,8 +500,6 @@ class Trainer(TrainerEventTrigger):
    @driver.setter
    def driver(self, driver: Driver):
        self._driver = driver

    @property
@@ -591,7 +591,9 @@ class Trainer(TrainerEventTrigger):
        # 1. callback states and the filter state of each callback's callback functions;
        # 2. trainer_state;
        states = {"callback_states": self.on_save_checkpoint(),
                  "trainer_state": self.trainer_state.state_dict(),
                  'num_consumed_batches': self.batch_idx_in_epoch - getattr(self, 'start_batch_idx_in_epoch', 0)
                  }

        # 3. validate filter state;
        if self.evaluator is not None:
@@ -668,6 +670,10 @@ class Trainer(TrainerEventTrigger):
        # The rule is that 'batches still to come' + 'batch_idx_in_epoch' should equal 'the total number of
        # batches had training not been interrupted'. Since 'batches still to come' is determined by how many
        # samples are left, the equation can only be balanced by adjusting 'batch_idx_in_epoch'.
        self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch')
        self.trainer_state.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + \
                                                    self.batch_idx_in_epoch
        # this prevents the user from saving again before the current epoch has finished after a Trainer.load
        self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch

        # 5. restore the state of every callback;
        self.on_load_checkpoint(states["callback_states"])
@@ -692,13 +698,15 @@ class Trainer(TrainerEventTrigger):
    def zero_grad(self):
        if (self.global_forward_batches + 1) % self.accumulation_steps == 0:
            self.on_before_zero_grad(self.optimizers)
            self.driver.zero_grad(self.set_grad_to_none)
            self.on_after_zero_grad(self.optimizers)

    def step(self):
        if (self.global_forward_batches + 1) % self.accumulation_steps == 0:
            self.on_before_optimizers_step(self.optimizers)
            self.driver.step()
            self.on_after_optimizers_step(self.optimizers)

    def move_data_to_device(self, batch):
        return self.driver.move_data_to_device(batch)
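The (self.global_forward_batches + 1) % self.accumulation_steps == 0 check above means that with
accumulation_steps=4, the optimizer step and gradient zeroing actually run after batches 4, 8, 12, ...:

accumulation_steps = 4
for global_forward_batches in range(8):
    if (global_forward_batches + 1) % accumulation_steps == 0:
        print("optimizer step after batch", global_forward_batches + 1)  # prints 4 and 8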
@@ -796,4 +804,19 @@ class Trainer(TrainerEventTrigger):
    def total_batches(self, total_batches: int):
        self.trainer_state.total_batches = total_batches

    """ driver property """

    @property
    def model_device(self):
        return self.driver.model_device

    @property
    def data_device(self):
        return self.driver.data_device
@@ -60,7 +60,7 @@ class TrainerState:
    cur_epoch_idx: which epoch is currently running;
    global_forward_batches: how many steps the model has forwarded in total so far;
    batch_idx_in_epoch: which step of the current epoch training is on;
    num_batches_per_epoch: how many steps one epoch forwards;
    total_batches: the number of steps the whole training run forwards; note that
        total_batches = num_batches_per_epoch * n_epochs;
    """
    n_epochs: Optional[int] = None  # recomputed in any case
@@ -1,8 +1,9 @@
from collections.abc import Iterator
import inspect
from typing import Dict

from fastNLP.core.callbacks import CallbackManager
from .state import TrainerState
from fastNLP.core.utils.utils import _check_valid_parameters_number


class TrainerEventTrigger:
@@ -68,12 +69,18 @@ class TrainerEventTrigger:
    def on_after_backward(self):
        self.callback_manager.on_after_backward(self)

    def on_before_optimizers_step(self, optimizers):
        self.callback_manager.on_before_optimizers_step(self, optimizers)

    def on_after_optimizers_step(self, optimizers):
        self.callback_manager.on_after_optimizers_step(self, optimizers)

    def on_before_zero_grad(self, optimizers):
        self.callback_manager.on_before_zero_grad(self, optimizers)

    def on_after_zero_grad(self, optimizers):
        self.callback_manager.on_after_zero_grad(self, optimizers)

    def on_validate_begin(self):
        self.callback_manager.on_validate_begin(self)
@@ -119,5 +126,8 @@ class _TruncatedDataLoader:
        return getattr(self.dataloader, item)

def check_validate_every(validate_every):
    if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0):
        raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.")
    if callable(validate_every):
        _check_valid_parameters_number(validate_every, expected_params=['trainer'])
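A hedged sketch of what _check_valid_parameters_number is used for above (the actual implementation in
fastNLP.core.utils.utils may differ):

def _check_valid_parameters_number_sketch(fn, expected_params, fn_name=None):
    # raise if fn cannot be called with len(expected_params) positional arguments
    sig = inspect.signature(fn)
    try:
        sig.bind(*expected_params)
    except TypeError as e:
        raise TypeError(f"`{fn_name or fn}` should accept parameters {expected_params}.") from e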
@@ -54,7 +54,7 @@ class TorchDataLoader(DataLoader):
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: Optional[Callable] = None,
                 multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
                 persistent_workers: bool = False, as_numpy: bool = False, **kwargs) -> None:
        """
        :param dataset: a data container implementing __getitem__ and __len__
@@ -178,10 +178,11 @@ class DataSet:
        elif isinstance(idx, slice):
            if idx.start is not None and (idx.start >= len(self) or idx.start <= -len(self)):
                raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self) - 1}")
            dataset = DataSet()
            for field_name, field in self.field_arrays.items():
                dataset.add_field(field_name=field_name, fields=field.content[idx])
            dataset.collate_fns = deepcopy(self.collate_fns)
            return dataset
        elif isinstance(idx, str):
            if idx not in self:
                raise KeyError("No such field called {} in DataSet.".format(idx))
@@ -192,6 +193,7 @@ class DataSet:
                assert isinstance(i, int), "Only int index allowed."
                instance = self[i]
                dataset.append(instance)
            dataset.collate_fns = deepcopy(self.collate_fns)
            return dataset
        else:
            raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
@@ -674,6 +676,8 @@ class DataSet:
            dev_set.append(self[idx])
        for idx in train_indices:
            train_set.append(self[idx])
        dev_set.collate_fns = deepcopy(self.collate_fns)
        train_set.collate_fns = deepcopy(self.collate_fns)

        return dev_set, train_set
@@ -788,13 +792,14 @@ class DataSet:
    def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
        """
        Set the padding value for each field_name; defaults to 0. This method only has an effect when an
        AutoCollator is present.
        When val=None, none of the given field_names are padded

        :param field_names: field names that exist in the dataset
        :param val: defaults to 0. If None, the field is not padded.
        :return:
        """
        # TODO must not be empty
        for field_name in field_names:
            self.collate_fns.set_pad_val(field_name, val=val)

@@ -805,6 +810,7 @@ class DataSet:
        :param field_names:
        :return:
        """
        self.collate_fns.set_input(*field_names)

    def get_collator(self) -> _MultiCollator:
@@ -66,7 +66,7 @@ class JittorDriver(Driver):
        if mode == "validate":
            if not hasattr(model, "validate_step"):
                if hasattr(model, "test_step"):
                    logger.warning_once(
                        "Your model does not have 'validate_step' method but has 'test_step' method, but you"
                        "are using 'mode=validate', we are going to use 'test_step' to substitute for"
                        "'validate_step'.")
@@ -74,7 +74,7 @@ class JittorDriver(Driver):
        else:
            if not hasattr(model, "test_step"):
                if hasattr(model, "validate_step"):
                    logger.warning_once("Your model does not have 'test_step' method but has 'validate_step' method, but you"
                                        "are using 'mode=test', we are going to use 'validate_step' to substitute for"
                                        "'test_step'.")
@@ -10,6 +10,8 @@ from .utils import (
    _MODE_PARAMETER,
    get_device_from_visible,
    reset_seed,
    replace_sampler,
    replace_batch_sampler,
)
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
@@ -19,8 +21,17 @@ from fastNLP.core.utils import (
    paddle_move_data_to_device,
    is_in_paddle_dist,
)
from fastNLP.core.samplers import (
    RandomBatchSampler,
    ReproducibleSampler,
    ReproducibleBatchSampler,
    RandomSampler,
    UnrepeatedSampler,
    UnrepeatedSequentialSampler,
    re_instantiate_sampler,
    conversion_between_reproducible_and_unrepeated_sampler,
)
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED
from fastNLP.core.log import logger

if _NEED_IMPORT_PADDLE:
@@ -93,8 +104,8 @@ class PaddleFleetDriver(PaddleDriver):
            # so we simply set model_device to None;
            self._model_device = None

            def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call):
                if isinstance(batch, Dict) and not wo_auto_param_call:
                    return auto_param_call(step_fn, batch, signature_fn=signature_fn)
                else:
                    return self._validate_step(batch)
@@ -105,23 +116,21 @@ class PaddleFleetDriver(PaddleDriver):
                    "Notice your model is a `paddle.DataParallel` model. And your "
                    "model also implements the `train_step` method, which we can not call actually, we will"
                    " call `forward` function instead of `train_step` and you should note that.")
            self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)

            if hasattr(model, "validate_step"):
                logger.warning(
                    "Notice your model is a `paddle.DataParallel` model. And your "
                    "model also implements the `validate_step` method, which we can not call actually, "
                    "we will call `forward` function instead of `validate_step` and you should note that.")
            self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)

            if hasattr(model, "test_step"):
                logger.warning(
                    "Notice your model is a `paddle.DataParallel` model. And your "
                    "model also implements the `test_step` method, which we can not call actually, we will"
                    " call `forward` function instead of `test_step` and you should note that.")
            self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)

        # when the `device` parameter is None and this parameter is not None, the data is moved to the specified machine;
        self._data_device = kwargs.get("data_device", None)
@@ -235,7 +244,6 @@ class PaddleFleetDriver(PaddleDriver): | |||
""" | |||
if self.local_rank == 0: | |||
# 是 rank0 的话,则拉起其它子进程 | |||
print("in launcher") | |||
launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc) | |||
launcher.launch() | |||
# set parameters and initialize the distributed environment | |||
@@ -253,7 +261,6 @@ class PaddleFleetDriver(PaddleDriver): | |||
When the user launches with `python -m paddle.distributed.launch xxx.py`, we need to | |||
derive the various attributes from the environment variables set by paddle | |||
""" | |||
print("set_from_env") | |||
self.world_size = dist.get_world_size() | |||
self.global_rank = dist.get_rank() | |||
@@ -267,9 +274,9 @@ class PaddleFleetDriver(PaddleDriver): | |||
**self._fleet_kwargs | |||
) | |||
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}) | |||
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}) | |||
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}) | |||
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call) | |||
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call) | |||
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call) | |||
self._configured = True | |||
@@ -312,67 +319,90 @@ class PaddleFleetDriver(PaddleDriver): | |||
def test_step(self, batch): | |||
return self._test_step(batch) | |||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]], | |||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]], | |||
reproducible: bool = False, sampler_or_batch_sampler=None): | |||
# IterableDataset is not supported for now | |||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | |||
"FastNLP does not support `IteratorDataset` now." | |||
# if dist is a ReproducibleBatchSampler / ReproducibleSampler, this is driver.load being called during checkpoint resumption; | |||
if isinstance(dist, ReproducibleBatchSampler): | |||
dist.set_distributed( | |||
num_replicas=self.world_size, | |||
rank=self.global_rank, | |||
pad=True | |||
) | |||
return replace_batch_sampler(dataloader, dist) | |||
if isinstance(dist, ReproducibleSampler): | |||
dataloader.batch_sampler.sampler = dist | |||
return dataloader | |||
# paddle's BatchSampler and DataLoader have no shuffle member, so we can only infer it from the sampler | |||
# but the subclass DistributedBatchSampler does have a shuffle member | |||
# therefore we use type() for a strict comparison | |||
if type(dataloader.batch_sampler) == BatchSampler: | |||
shuffle = isinstance(dataloader.batch_sampler.sampler, RandomSampler) | |||
else: | |||
shuffle = dataloader.batch_sampler.shuffle | |||
dist.set_distributed( | |||
num_replicas=self.world_size, | |||
rank=self.global_rank, | |||
pad=True | |||
) | |||
return replace_sampler(dataloader, dist) | |||
# if dist is a str or None, this is being called during trainer initialization; | |||
# trainer, evaluator | |||
if dist is None: | |||
if reproducible: | |||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our " | |||
"control.") | |||
else: | |||
args = self.get_dataloader_args(dataloader) | |||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||
return replace_batch_sampler(dataloader, batch_sampler) | |||
if isinstance(args.sampler, ReproducibleSampler): | |||
sampler = re_instantiate_sampler(args.sampler) | |||
return replace_sampler(dataloader, sampler) | |||
return dataloader | |||
# trainer | |||
elif dist == "dist": | |||
args = self.get_dataloader_args(dataloader) | |||
# if the user's trainer.use_dist_sampler is True, whether they resume from a checkpoint does not affect the behavior here; | |||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||
dataloader.batch_sampler.sampler.set_distributed( | |||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||
batch_sampler.set_distributed( | |||
num_replicas=self.world_size, | |||
rank=self.global_rank, | |||
pad=True | |||
) | |||
return dataloader | |||
return replace_batch_sampler(dataloader, batch_sampler) | |||
elif isinstance(args.sampler, ReproducibleSampler): | |||
sampler = re_instantiate_sampler(args.sampler) | |||
sampler.set_distributed( | |||
num_replicas=self.world_size, | |||
rank=self.global_rank, | |||
pad=True | |||
) | |||
return replace_sampler(dataloader, sampler) | |||
else: | |||
sampler = RandomSampler( | |||
dataset=dataloader.dataset, | |||
shuffle=shuffle, | |||
seed=int(os.environ.get("FASTNLP_SEED", 0)) | |||
dataset=args.dataset, | |||
shuffle=args.shuffle, | |||
seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0)) | |||
) | |||
sampler.set_distributed( | |||
num_replicas=self.world_size, | |||
rank=self.global_rank, | |||
pad=True | |||
) | |||
dataloader.batch_sampler.sampler = sampler | |||
return dataloader | |||
return replace_sampler(dataloader, sampler) | |||
# evaluator | |||
elif dist == "unrepeatdist": | |||
sampler = UnrepeatedRandomSampler( | |||
dataset=dataloader.dataset, | |||
shuffle=shuffle, | |||
seed=int(os.environ.get("FASTNLP_SEED", 0)) | |||
) | |||
args = self.get_dataloader_args(dataloader) | |||
if isinstance(args.sampler, ReproducibleSampler): | |||
sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) | |||
elif not isinstance(args.sampler, UnrepeatedSampler): | |||
sampler = UnrepeatedSequentialSampler( | |||
dataset=args.dataset | |||
) | |||
else: | |||
sampler = re_instantiate_sampler(args.sampler) | |||
sampler.set_distributed( | |||
num_replicas=self.world_size, | |||
rank=self.global_rank | |||
) | |||
dataloader.batch_sampler.sampler = sampler | |||
return dataloader | |||
return replace_sampler(dataloader, sampler) | |||
else: | |||
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | |||
@@ -38,23 +38,19 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||
if driver not in {"paddle", "fleet"}: | |||
raise ValueError("Parameter `driver` can only be one of these values: ['paddle', 'fleet'].") | |||
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") | |||
user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES") | |||
# priority: user > cuda | |||
# validate `device` for the single-machine case | |||
# in the distributed case, validity is checked via world_device | |||
if user_visible_devices != "": | |||
_could_use_device_num = len(user_visible_devices.split(",")) | |||
elif cuda_visible_devices is not None: | |||
_could_use_device_num = len(cuda_visible_devices.split(",")) | |||
else: | |||
_could_use_device_num = paddle.device.cuda.device_count() | |||
if user_visible_devices is None: | |||
raise RuntimeError("This situation cannot happen, please report a bug to us.") | |||
_could_use_device_num = len(user_visible_devices.split(",")) | |||
if isinstance(device, int): | |||
if device < 0 and device != -1: | |||
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | |||
# if device >= _could_use_device_num: | |||
# raise ValueError("The gpu device that parameter `device` specifies is not existed.") | |||
device = f"gpu:{device}" | |||
if device >= _could_use_device_num: | |||
raise ValueError("The gpu device that parameter `device` specifies is not existed.") | |||
if device != -1: | |||
device = f"gpu:{device}" | |||
else: | |||
device = list(range(_could_use_device_num)) | |||
elif isinstance(device, Sequence) and not isinstance(device, str): | |||
device = list(set(device)) | |||
for each in device: | |||
@@ -62,6 +58,9 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be 'int' type.") | |||
elif each < 0: | |||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be bigger than 0.") | |||
elif each >= _could_use_device_num: | |||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should not be bigger than" | |||
" the available gpu number.") | |||
if len(device) == 1: | |||
# a list like [1] is treated as single-gpu. | |||
device = device[0] | |||
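A compact sketch of the normalization this validation implies (hypothetical helper; `n_gpus` stands in for `_could_use_device_num`):

def normalize_device(device, n_gpus=4):
    if isinstance(device, int):
        # -1 means "all visible cards"; other ints become a paddle device string
        return list(range(n_gpus)) if device == -1 else f"gpu:{device}"
    if isinstance(device, (list, tuple)):
        device = sorted(set(device))
        return device[0] if len(device) == 1 else device   # [1] counts as single-card
    return device

assert normalize_device(2) == "gpu:2"
assert normalize_device(-1) == [0, 1, 2, 3]
assert normalize_device([1]) == 1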
@@ -1,21 +1,36 @@ | |||
import os | |||
import random | |||
from typing import Union, Optional, Callable, Dict | |||
from typing import Union, Optional, Dict | |||
from pathlib import Path | |||
from functools import partial | |||
from dataclasses import dataclass | |||
import numpy as np | |||
from .utils import _build_fp16_env | |||
from .utils import _build_fp16_env, optimizer_state_to_device | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
from fastNLP.core.drivers.driver import Driver | |||
from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device | |||
from fastNLP.envs import rank_zero_call | |||
from fastNLP.envs import FASTNLP_SEED_WORKERS | |||
from fastNLP.envs import ( | |||
FASTNLP_SEED_WORKERS, | |||
FASTNLP_MODEL_FILENAME, | |||
FASTNLP_CHECKPOINT_FILENAME, | |||
FASTNLP_GLOBAL_RANK, | |||
rank_zero_call, | |||
) | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler | |||
if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
from paddle.io import DataLoader, IterableDataset | |||
from paddle.io import ( | |||
DataLoader, | |||
IterableDataset, | |||
Dataset, | |||
Sampler, | |||
BatchSampler, | |||
RandomSampler, | |||
) | |||
from paddle.optimizer import Optimizer | |||
_reduces = { | |||
@@ -41,6 +56,9 @@ class PaddleDriver(Driver): | |||
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | |||
self.grad_scaler = _grad_scaler() | |||
# whether to disable the automatic parameter matching in auto_param_call; | |||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | |||
def zero_grad(self, set_to_none: bool = False): | |||
r""" | |||
Zero the gradients; this should be done directly through the optimizers; | |||
@@ -48,8 +66,8 @@ class PaddleDriver(Driver): | |||
:param set_to_none: whether the gradients should be set to None directly; this parameter has no effect in Paddle. | |||
""" | |||
# if set_to_none: | |||
# log.warning("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.") | |||
if set_to_none: | |||
logger.warning_once("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.") | |||
for optimizer in self.optimizers: | |||
optimizer.clear_grad() | |||
@@ -69,6 +87,8 @@ class PaddleDriver(Driver): | |||
# TODO for now we forbid dataloaders whose dataset is an IterableDataset; | |||
if isinstance(dataloader.dataset, IterableDataset): | |||
raise TypeError("`IterableDataset` is not allowed.") | |||
if dataloader.batch_sampler is None and dataloader.batch_size is None: | |||
raise ValueError(f"At least one of `{dataloader_name}`'s `batch_sampler` and `batch_size` should be set.") | |||
else: | |||
if not isinstance(dataloader, Dict): | |||
raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.") | |||
@@ -79,6 +99,9 @@ class PaddleDriver(Driver): | |||
f"type, not {type(each_dataloader)}.") | |||
if isinstance(each_dataloader.dataset, IterableDataset): | |||
raise TypeError("`IterableDataset` is not allowed.") | |||
if each_dataloader.batch_sampler is None and each_dataloader.batch_size is None: | |||
raise ValueError(f"For each dataloader of parameter `{dataloader_name}`, at least one of " | |||
f"`batch_sampler` and `batch_size` should be set.") | |||
@staticmethod | |||
def _check_optimizer_legality(optimizers): | |||
@@ -110,7 +133,7 @@ class PaddleDriver(Driver): | |||
else: | |||
if not hasattr(model, "test_step"): | |||
if hasattr(model, "validate_step"): | |||
logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you" | |||
logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" | |||
"are using 'Evaluator.test', we are going to use 'validate_step' to substitute for" | |||
"'test_step'.") | |||
@@ -153,45 +176,55 @@ class PaddleDriver(Driver): | |||
getattr(self.model, mode)() | |||
@rank_zero_call | |||
def save_model(self, filepath: str, only_state_dict: bool = True, model_save_fn: Optional[Callable]=None, **kwargs): | |||
def save_model(self, filepath: str, only_state_dict: bool = True, **kwargs): | |||
r""" | |||
Function for saving the model; note that `save` is the separate function used for checkpoint resumption; | |||
if `model_save_fn` is callable, we run that function directly; | |||
:param filepath: location of the file to save to (must include the file name); | |||
:param only_state_dict: whether to save only the model's `state_dict`; note this parameter only takes effect when `model_save_fn` is None; | |||
:param model_save_fn: a user-provided function replacing the saving logic here; if it is not None, we call model_save_fn(path); | |||
:param only_state_dict: whether to save only the model's `state_dict`; if False, `paddle.jit.save` is used to | |||
save the whole model, in which case the `input_spec` parameter must be passed, otherwise loading will fail. | |||
:param kwargs: | |||
input_spec: describes the inputs of the stored model's forward method; required when `only_state_dict` is False, otherwise loading will fail. | |||
It can be given as InputSpec objects or example Tensors. For details see the paddle documentation | |||
of `paddle.jit.save`: | |||
https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/jit/save_cn.html#save | |||
:return: | |||
""" | |||
if model_save_fn is not None: | |||
model_save_fn(filepath) | |||
model = self.unwrap_model() | |||
if isinstance(filepath, Path): | |||
filepath = str(filepath) | |||
if only_state_dict: | |||
states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} | |||
paddle.save(states, filepath) | |||
else: | |||
model = self.unwrap_model() | |||
if only_state_dict: | |||
paddle.save(model.state_dict(), filepath) | |||
else: | |||
input_spec = kwargs.get("input_spec", None) | |||
if input_spec is None: | |||
raise Exception("To save the whole Paddle Layer, parameter 'input_spec' is needed.") | |||
paddle.jit.save(model, filepath, input_spec) | |||
# paddle needs extra parameters to save the whole model | |||
input_spec = kwargs.get("input_spec", None) | |||
if input_spec is None: | |||
raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.") | |||
paddle.jit.save(model, filepath, input_spec) | |||
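As a usage sketch of the two saving paths (a toy `paddle.nn.Linear` and made-up output paths; the `InputSpec` shape is only an example):

import paddle
from paddle.static import InputSpec

model = paddle.nn.Linear(10, 2)

# only_state_dict=True: plain parameter dict, loadable with paddle.load
paddle.save(model.state_dict(), "linear.pdparams")

# only_state_dict=False: the whole Layer is traced, so forward's inputs
# must be described via input_spec, e.g. a batch of float32 vectors
spec = [InputSpec(shape=[None, 10], dtype="float32", name="x")]
paddle.jit.save(model, "./saved/linear", input_spec=spec)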
@staticmethod | |||
@rank_zero_call | |||
def load_model(filepath: str, load_dict: bool = True): | |||
def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs): | |||
r""" | |||
Function for loading the model; note that `load` is the separate function used for checkpoint resumption; | |||
:param filepath: location of the file to load from (must include the file name); | |||
:param load_dict: whether to load the state_dict, default True. If the user set only_state_dict to False in save_model, | |||
i.e. saved the whole model, this parameter must be False as well | |||
:return: the result of loading the given file; | |||
:param only_state_dict: whether to load the state_dict, default True. | |||
:param kwargs: | |||
:return: | |||
""" | |||
if load_dict: | |||
return paddle.load(filepath) | |||
else: | |||
return paddle.jit.load(filepath) | |||
model = self.unwrap_model() | |||
if isinstance(filepath, Path): | |||
filepath = str(filepath) | |||
# in paddle, a model saved via paddle.jit.save can also be loaded as a state dict via paddle.load, | |||
# but then the given path must have the form dir/filename, otherwise an error is raised. | |||
dirname, filename = os.path.split(filepath) | |||
if not only_state_dict and dirname == "": | |||
# if a bare file name was passed, prepend a relative path | |||
filepath = os.path.join(".", filepath) | |||
model.load_dict(paddle.load(filepath)) | |||
@rank_zero_call | |||
def save(self, folder, states: Dict): | |||
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||
r""" | |||
Saving function for checkpoint resumption; it is responsible for saving the state_dict of the model and the optimizers; | |||
note that a driver should be stateless, i.e. its interface functions should return the same results no matter when they are called; therefore checkpoint resumption does not need to save the driver | |||
@@ -203,48 +236,114 @@ class PaddleDriver(Driver): | |||
:param states: a dict passed in by the trainer that already contains the states of the other objects that need | |||
to be saved for checkpoint resumption; the Driver only needs to save this object and does not need to understand it; in | |||
driver.load() these states must be handed back, and the value returned by load() must match the one passed in here. | |||
:param dataloader: the dataloader currently in use; its internal state must be saved so that training can later resume from the current iteration position. | |||
:param only_state_dict: whether to save only the model parameters; has no effect when should_save_model is False. | |||
:param should_save_model: whether the model should be saved; if False, the Driver is not responsible for saving the model. | |||
:return: | |||
""" | |||
# 1. save the model's state; | |||
model = self.unwrap_model() | |||
model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} | |||
# for a single-card driver, we really should not (for now) cater to users running single-card mode inside a DDP environment at the cost of efficiency; | |||
states["model_state_dict"] = model_state_dict | |||
# the dataloader passed in is the trainer's dataloader attribute: we never modify any of the driver's dataloaders, but instead change | |||
# trainer.dataloader to adapt the dataloader's state to the training or evaluation environment; | |||
# 1. the sampler's state, since we support resume training, i.e. recovering exactly to a specific batch; | |||
# after initialization, a paddle DataLoader's batch_sampler may be None, or the batch_sampler set by the user | |||
dataloader_args = self.get_dataloader_args(dataloader) | |||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||
sampler = dataloader_args.batch_sampler | |||
elif dataloader_args.sampler: | |||
sampler = dataloader_args.sampler | |||
else: | |||
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | |||
num_consumed_batches = states.pop('num_consumed_batches') | |||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | |||
sampler_states = sampler.state_dict() | |||
# if present, num_consumed_samples needs special handling: because the DataLoader prefetches, using the sampler's | |||
# num_consumed_samples directly would overcount what has actually been consumed. | |||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||
if num_consumed_samples_array is not None: | |||
if isinstance(sampler, ReproducibleSampler):  # for a plain sampler we have to account for batch_size. | |||
try: | |||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||
except TypeError:  # batch_size may be None, in which case we can only lose precision | |||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||
else: | |||
raise RuntimeError( | |||
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | |||
# 2. save the optimizers' state; | |||
# 2. save the model's state; | |||
if should_save_model: | |||
self.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs) | |||
if only_state_dict: | |||
logger.debug("Save model state dict.") | |||
else: | |||
logger.debug("Save model.") | |||
# 3. save the optimizers' state; | |||
optimizers_state_dict = {} | |||
for i in range(len(self.optimizers)): | |||
optimizer: Optimizer = self.optimizers[i] | |||
optimizer_state = optimizer.state_dict() | |||
optimizer_state = {name: param.cpu().detach().clone() for name, param in optimizer_state.items()} | |||
optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu") | |||
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | |||
states["optimizers_state_dict"] = optimizers_state_dict | |||
paddle.save(states, folder) | |||
def load(self, filepath) -> Dict: | |||
r""" | |||
断点重训的加载函数,注意该函数会负责读取数据,并且恢复模型和 optimizers 的 state_dict 等; | |||
driver 实例需要在该函数中先加载模型和 optimizers 的 state_dict,然后将一个 state 字典返回给 trainer 。 | |||
因此 save 函数和 load 函数的接受和返回值应该是对应的; | |||
该函数需要在所有 rank 上执行。 | |||
logger.debug("Save optimizer state dict.") | |||
states["optimizers_state_dict"] = optimizers_state_dict | |||
paddle.save(states, str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) | |||
:param filepath: 保存断点重训的状态的文件名; | |||
:return: 需要返回 save 函数输入的 states 内容; | |||
""" | |||
states = paddle.load(filepath) | |||
def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||
states = paddle.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) | |||
# 1. load the optimizers' state; | |||
optimizers_state_dict = states["optimizers_state_dict"] | |||
for i in range(len(self.optimizers)): | |||
optimizer: paddle.optimizer.Optimizer = self.optimizers[i] | |||
optimizer: Optimizer = self.optimizers[i] | |||
optimizer.set_state_dict(optimizers_state_dict[f"optimizer{i}"]) | |||
logger.debug("Load optimizer state dict.") | |||
# 2. load the model's state; | |||
model = self.unwrap_model() | |||
model.load_dict(states["model_state_dict"]) | |||
if should_load_model: | |||
self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict) | |||
if only_state_dict: | |||
logger.debug("Load model state dict.") | |||
else: | |||
logger.debug("Load model.") | |||
# 3. restore the sampler's state; | |||
dataloader_args = self.get_dataloader_args(dataloader) | |||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||
sampler = dataloader_args.batch_sampler | |||
elif isinstance(dataloader_args.sampler, ReproducibleSampler): | |||
sampler = dataloader_args.sampler | |||
elif self.is_distributed(): | |||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") | |||
else: | |||
sampler = RandomBatchSampler( | |||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | |||
batch_size=dataloader_args.batch_size, | |||
drop_last=dataloader_args.drop_last | |||
) | |||
sampler.load_state_dict(states['sampler_states']) | |||
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | |||
# 4. adjust trainer_state.batch_idx_in_epoch | |||
# here sampler is a RandomSampler-like sampler, not a batch_sampler; | |||
if not isinstance(sampler, ReproducibleBatchSampler): | |||
if dataloader_args.drop_last: | |||
batch_idx_in_epoch = len( | |||
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size | |||
else: | |||
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ | |||
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size | |||
# here sampler is a batch_sampler; | |||
else: | |||
batch_idx_in_epoch = sampler.batch_idx_in_epoch | |||
states["batch_idx_in_epoch"] = batch_idx_in_epoch | |||
self.barrier() | |||
return states | |||
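A quick numeric check of the `batch_idx_in_epoch` arithmetic above, with illustrative numbers (`num_left_samples` is whatever the sampler has not yet produced):

batch_size, total, num_left = 3, 10, 4        # 10 samples, 4 not yet consumed

# drop_last=False: ceil-divide total and remaining batches, then subtract
batches_total = (total + batch_size - 1) // batch_size       # ceil(10 / 3) == 4
batches_left = (num_left + batch_size - 1) // batch_size     # ceil(4 / 3)  == 2
assert batches_total - batches_left == 2                     # 2 batches already done

# drop_last=True: floor-divide instead; the tail batch is never produced
assert total // batch_size - num_left // batch_size == 2     # 3 - 1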
def get_evaluate_context(self): | |||
@@ -282,7 +381,7 @@ class PaddleDriver(Driver): | |||
`randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_. | |||
""" | |||
# implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 | |||
global_rank = rank if rank is not None else rank_zero_call.rank | |||
global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) | |||
# TODO gpu | |||
process_seed = paddle.fluid.core.default_cpu_generator().initial_seed() | |||
# back out the base seed so we can use all the bits | |||
@@ -313,3 +412,64 @@ class PaddleDriver(Driver): | |||
""" | |||
if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): | |||
dataloader.batch_sampler.set_epoch(cur_epoch_idx) | |||
@staticmethod | |||
def get_dataloader_args(dataloader: "DataLoader"): | |||
""" | |||
Get the dataloader's dataset, (batch_)sampler, batch_size, shuffle and drop_last attributes; | |||
""" | |||
@dataclass | |||
class Res: | |||
dataset: Optional[Dataset] = None | |||
batch_sampler: Optional[BatchSampler] = None | |||
sampler: Optional[Sampler] = None | |||
batch_size: Optional[int] = None | |||
shuffle: Optional[bool] = None | |||
drop_last: Optional[bool] = None | |||
res = Res() | |||
# a paddle DataLoader is guaranteed to have a dataset attribute; | |||
res.dataset = dataloader.dataset | |||
if dataloader.batch_sampler is not None: | |||
# note that in paddle we require the batch_sampler not to be None | |||
res.batch_sampler = dataloader.batch_sampler | |||
if hasattr(dataloader.batch_sampler, "batch_size"): | |||
res.batch_size = getattr(dataloader.batch_sampler, "batch_size") | |||
# 用户使用的是自己的 batch_sampler 并且其没有 "batch_size" 属性; | |||
else: | |||
dataloader_iter = iter(dataloader) | |||
pre_sample = next(dataloader_iter) | |||
res.batch_size = pre_sample.shape[0] | |||
if hasattr(dataloader.batch_sampler, "sampler"): | |||
res.sampler = dataloader.batch_sampler.sampler | |||
if hasattr(dataloader.batch_sampler.sampler, "shuffle"): | |||
res.shuffle = dataloader.batch_sampler.sampler.shuffle | |||
elif isinstance(dataloader.batch_sampler.sampler, RandomSampler): | |||
res.shuffle = True | |||
else: | |||
res.shuffle = False | |||
# the RandomBatchSampler case | |||
elif hasattr(dataloader.batch_sampler, "batch_sampler"): | |||
batch_sampler = dataloader.batch_sampler.batch_sampler | |||
res.sampler = batch_sampler.sampler | |||
if hasattr(batch_sampler.sampler, "shuffle"): | |||
res.shuffle = batch_sampler.sampler.shuffle | |||
elif isinstance(batch_sampler.sampler, RandomSampler): | |||
res.shuffle = True | |||
else: | |||
res.shuffle = False | |||
else: | |||
res.sampler = None | |||
res.shuffle = False | |||
if hasattr(dataloader.batch_sampler, "drop_last"): | |||
res.drop_last = getattr(dataloader.batch_sampler, "drop_last") | |||
# 用户使用的是自己的 batch_sampler 并且其没有 "drop_last" 属性; | |||
else: | |||
res.drop_last = False | |||
return res |
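A hedged usage sketch of `get_dataloader_args` on a plain paddle DataLoader (toy data; the attribute values in the comments are what the logic above should produce):

import numpy as np
import paddle
from paddle.io import DataLoader, TensorDataset

ds = TensorDataset([paddle.to_tensor(np.zeros((10, 4), dtype="float32"))])
dl = DataLoader(ds, batch_size=4, shuffle=True, drop_last=False)

args = PaddleDriver.get_dataloader_args(dl)
# expected: args.dataset is ds, args.batch_size == 4,
# args.shuffle is True and args.drop_last is False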
@@ -2,6 +2,7 @@ import os | |||
from typing import Optional, Dict, Union | |||
from .paddle_driver import PaddleDriver | |||
from .utils import replace_batch_sampler, replace_sampler, get_device_from_visible | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES | |||
from fastNLP.core.utils import ( | |||
@@ -10,7 +11,12 @@ from fastNLP.core.utils import ( | |||
get_paddle_device_id, | |||
paddle_move_data_to_device, | |||
) | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||
from fastNLP.core.samplers import ( | |||
ReproducibleBatchSampler, | |||
RandomBatchSampler, | |||
ReproducibleSampler, | |||
re_instantiate_sampler, | |||
) | |||
from fastNLP.core.log import logger | |||
if _NEED_IMPORT_PADDLE: | |||
@@ -22,16 +28,13 @@ __all__ = [ | |||
] | |||
class PaddleSingleDriver(PaddleDriver): | |||
def __init__(self, model, device: Optional[str], fp16: Optional[bool] = False, **kwargs): | |||
def __init__(self, model, device: str, fp16: Optional[bool] = False, **kwargs): | |||
super(PaddleSingleDriver, self).__init__(model, fp16=fp16, **kwargs) | |||
if device is None: | |||
raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.") | |||
if isinstance(device, int): | |||
self.model_device = get_paddle_gpu_str(device) | |||
else: | |||
self.model_device = device | |||
self.model_device = get_paddle_gpu_str(device) | |||
self.local_rank = 0 | |||
self.global_rank = 0 | |||
@@ -93,18 +96,18 @@ class PaddleSingleDriver(PaddleDriver): | |||
self._test_signature_fn = model.forward | |||
def setup(self): | |||
user_visible_devices = os.environ[USER_CUDA_VISIBLE_DEVICES] | |||
device_id = get_paddle_device_id(self.model_device) | |||
if user_visible_devices is not None and user_visible_devices != "": | |||
# not empty, meaning the user set CUDA_VISIBLE_DEVICES | |||
device_id = user_visible_devices.split(",")[device_id] | |||
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) | |||
paddle.device.set_device("gpu:0") | |||
self.model.to("gpu:0") | |||
device = self.model_device | |||
if device != "cpu": | |||
device_id = get_paddle_device_id(device) | |||
device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id] | |||
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) | |||
device = get_device_from_visible(device, output_type=str) | |||
paddle.device.set_device(device) | |||
self.model.to(device) | |||
def train_step(self, batch) -> Dict: | |||
# if batch is a Dict, we do parameter matching for it by default; otherwise it is passed straight to `train_step` for the user to handle; | |||
if isinstance(batch, Dict): | |||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | |||
else: | |||
return self._train_step(batch) | |||
@@ -118,13 +121,13 @@ class PaddleSingleDriver(PaddleDriver): | |||
self.grad_scaler.update() | |||
def validate_step(self, batch) -> Dict: | |||
if isinstance(batch, Dict): | |||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | |||
else: | |||
return self._validate_step(batch) | |||
def test_step(self, batch) -> Dict: | |||
if isinstance(batch, Dict): | |||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | |||
else: | |||
return self._test_step(batch) | |||
@@ -133,38 +136,40 @@ class PaddleSingleDriver(PaddleDriver): | |||
r""" | |||
Move the data to the specified device; batch may be a list or a dict, or a nested structure of them. | |||
Note that in Paddle this may cause problems when it disagrees with the configured device. | |||
On a single card, since CUDA_VISIBLE_DEVICES is always restricted to one device, data is effectively only moved to `gpu:0` | |||
:return: the batch object moved to the specified device; | |||
""" | |||
return paddle_move_data_to_device(batch, "gpu:0") | |||
device = get_device_from_visible(self.data_device) | |||
return paddle_move_data_to_device(batch, device) | |||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, | |||
reproducible: bool = False): | |||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], | |||
reproducible: bool = False, sampler_or_batch_sampler=None): | |||
# IterableDataset is not supported for now | |||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | |||
"FastNLP does not support `IteratorDataset` now." | |||
# if dist is a ReproducibleBatchSampler / ReproducibleIterator, this is driver.load being called during checkpoint resumption; | |||
if isinstance(dist, ReproducibleBatchSampler): | |||
dataloader.batch_sampler = dist | |||
return dataloader | |||
if isinstance(dist, ReproducibleSampler): | |||
dataloader.batch_sampler.sampler = dist | |||
return dataloader | |||
return replace_batch_sampler(dataloader, dist) | |||
elif isinstance(dist, ReproducibleSampler): | |||
return replace_sampler(dataloader, dist) | |||
# if dist is a str or None, this is being called during trainer initialization; | |||
args = self.get_dataloader_args(dataloader) | |||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||
return replace_batch_sampler(dataloader, batch_sampler) | |||
elif isinstance(args.sampler, ReproducibleSampler): | |||
sampler = re_instantiate_sampler(args.sampler) | |||
return replace_sampler(dataloader, sampler) | |||
if reproducible: | |||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||
return dataloader | |||
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): | |||
return dataloader | |||
else: | |||
# TODO | |||
batch_sampler = ReproducibleBatchSampler( | |||
batch_sampler=dataloader.batch_sampler, | |||
batch_size=dataloader.batch_sampler.batch_size, | |||
drop_last=dataloader.drop_last | |||
) | |||
dataloader.batch_sampler = batch_sampler | |||
return dataloader | |||
batch_sampler = RandomBatchSampler( | |||
batch_sampler=args.batch_sampler, | |||
batch_size=args.batch_size, | |||
drop_last=args.drop_last | |||
) | |||
return replace_batch_sampler(dataloader, batch_sampler) | |||
else: | |||
return dataloader | |||
@@ -4,12 +4,14 @@ import struct | |||
import random | |||
import inspect | |||
import numpy as np | |||
from copy import deepcopy | |||
from contextlib import ExitStack, closing | |||
from enum import IntEnum | |||
from typing import Dict, Optional, Union | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
from fastNLP.core.utils import get_paddle_device_id, auto_param_call | |||
from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to | |||
from fastNLP.core.samplers import RandomSampler | |||
from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES | |||
from fastNLP.core.log import logger | |||
@@ -18,7 +20,7 @@ if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
from paddle import nn | |||
from paddle.nn import Layer | |||
from paddle.io import DataLoader, BatchSampler | |||
from paddle.io import DataLoader, BatchSampler, Dataset | |||
from paddle.amp import auto_cast, GradScaler | |||
else: | |||
from fastNLP.core.utils.dummy_class import DummyClass as Layer | |||
@@ -85,7 +87,7 @@ class ForwardState(IntEnum): | |||
TEST = 2 | |||
PREDICT = 3 | |||
_MODE_PARAMETER = "_forward_state" | |||
_MODE_PARAMETER = "forward_state" | |||
class _FleetWrappingModel(Layer): | |||
""" | |||
@@ -151,24 +153,25 @@ class _FleetWrappingModel(Layer): | |||
def forward(self, batch, **kwargs) -> Dict: | |||
_forward_state = kwargs.pop(_MODE_PARAMETER) | |||
forward_state = kwargs.pop(_MODE_PARAMETER) | |||
wo_auto_param_call = kwargs.pop("wo_auto_param_call") | |||
if _forward_state == ForwardState.TRAIN: | |||
if isinstance(batch, Dict): | |||
if forward_state == ForwardState.TRAIN: | |||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | |||
else: | |||
return self._train_step(batch) | |||
elif _forward_state == ForwardState.VALIDATE: | |||
if isinstance(batch, Dict): | |||
elif forward_state == ForwardState.VALIDATE: | |||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | |||
else: | |||
return self._validate_step(batch) | |||
elif _forward_state == ForwardState.TEST: | |||
if isinstance(batch, Dict): | |||
elif forward_state == ForwardState.TEST: | |||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | |||
else: | |||
return self._test_step(batch) | |||
elif _forward_state == ForwardState.PREDICT: | |||
elif forward_state == ForwardState.PREDICT: | |||
raise NotImplementedError("'PREDICT' mode has not been implemented.") | |||
else: | |||
raise NotImplementedError("You should direct a concrete mode.") | |||
@@ -205,7 +208,6 @@ class DummyGradScaler: | |||
def state_dict(self): | |||
return {} | |||
def _build_fp16_env(dummy=False): | |||
if dummy: | |||
auto_cast = ExitStack | |||
@@ -255,61 +257,77 @@ def get_host_name_ip(): | |||
except: | |||
return None | |||
def get_device_from_visible(device: Union[str, int]): | |||
def get_device_from_visible(device: Union[str, int], output_type=int): | |||
""" | |||
Get the actual device to use when CUDA_VISIBLE_DEVICES is set. | |||
E.g. with CUDA_VISIBLE_DEVICES=2,3 and device=3, 1 is returned. | |||
:param devices: the raw device name | |||
:param device: the raw device name | |||
:param output_type: type of the return value | |||
:return: the translated device id | |||
""" | |||
if output_type not in [int, str]: | |||
raise ValueError("Parameter `output_type` should be one of these types: [int, str]") | |||
if device == "cpu": | |||
return device | |||
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") | |||
idx = get_paddle_device_id(device) | |||
if cuda_visible_devices is None or cuda_visible_devices == "": | |||
# this branch should normally never be taken, because fastNLP force-injects CUDA_VISIBLE_DEVICES for paddle | |||
return idx | |||
raise RuntimeError("This situation should not happen, please report us this bug.") | |||
else: | |||
# use USER_CUDA_VISIBLE_DEVICES to get the devices the user expects | |||
user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) | |||
if user_visible_devices is not None and user_visible_devices != "": | |||
# not empty, meaning the user set CUDA_VISIBLE_DEVICES | |||
idx = user_visible_devices.split(",")[idx] | |||
else: | |||
idx = str(idx) | |||
if user_visible_devices is None: | |||
raise RuntimeError("This situation cannot happen, please report a bug to us.") | |||
idx = user_visible_devices.split(",")[idx] | |||
cuda_visible_devices_list = cuda_visible_devices.split(',') | |||
assert idx in cuda_visible_devices_list, "Can't find "\ | |||
"your devices %s in CUDA_VISIBLE_DEVICES[%s]."\ | |||
% (idx, cuda_visible_devices) | |||
if idx not in cuda_visible_devices_list: | |||
raise ValueError(f"Can't find your devices {idx} in CUDA_VISIBLE_DEVICES[{cuda_visible_devices}].") | |||
res = cuda_visible_devices_list.index(idx) | |||
return res | |||
if output_type == int: | |||
return res | |||
else: | |||
return f"gpu:{res}" | |||
def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"): | |||
# grab the instance attributes; | |||
def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler"): | |||
""" | |||
Rebuild a DataLoader around `batch_sampler`, replacing the `batch_sampler` without touching the original `dataloader`. | |||
User-customized DataLoaders are taken into account. | |||
""" | |||
# grab the instance attributes that do not start with an underscore; | |||
instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')} | |||
# get the default signature of the dataloader's '__init__'; | |||
# get the default signature of the dataloader's '__init__'; it yields the parameter names, default values and types | |||
init_params = dict(inspect.signature(dataloader.__init__).parameters) | |||
# the reason for handling this separately is that, when customizing a dataloader, users may for convenience define only a few | |||
# parameters explicitly and pass the rest via **kwargs; they may then have added some new parameters when instantiating their | |||
# dataloader (first, this step is necessary because it is the only way we can inject a sampler; second, the user may indeed | |||
# have added new parameters through **kwargs). If the user wrote "super().__init__(**kwargs)", we can only look the parameters up in DataLoader; | |||
# VAR_KEYWORD stands for **kwargs | |||
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | |||
if has_variadic_kwargs: | |||
init_params.update(dict(inspect.signature(DataLoader.__init__).parameters)) | |||
del init_params["self"] | |||
# since we may just have overwritten the customized dataloader's parameters with DataLoader's defaults, this has to be redone; | |||
# collect the parameters that appear both as instance attributes and as parameter names and whose values differ from the defaults | |||
non_default_params = {name for name, p in init_params.items() if | |||
name in instance_attrs and p.default != instance_attrs[name]} | |||
# add `dataset` as it might have been replaced with `*args` | |||
non_default_params.add("dataset") | |||
# collect the non-default parameters and their values | |||
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} | |||
reconstruct_args.update({"batch_sampler": sampler, "shuffle": False, "drop_last": False, "batch_size": 1}) | |||
# the class member backing persistent_workers has a leading underscore, so add it here explicitly | |||
reconstruct_args.update({ | |||
"batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1, | |||
"persistent_workers": dataloader._persistent_workers, | |||
}) | |||
# POSITIONAL_OR_KEYWORD stands for ordinary parameters | |||
# collect the ordinary __init__ parameters that have no default value and are not in reconstruct_args, | |||
# i.e. those that do not appear in both the __init__ signature and the instance attributes | |||
required_args = { | |||
p.name | |||
for p in init_params.values() | |||
@@ -323,12 +341,9 @@ def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"): | |||
required_args = sorted(required_args) | |||
dataloader_self_name = dataloader.__class__.__name__ | |||
raise Exception( | |||
f"Trying to inject `DistributedBatchSampler` into the `{dataloader_self_name}` instance. " | |||
f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. " | |||
"This would fail as some of the `__init__` arguments are not available as instance attributes. " | |||
f"The missing attributes are {required_args}. " | |||
f"HINT: If you wrote the `{dataloader_self_name}` class, define `self.missing_arg_name` or " | |||
"manually add the `DistributedBatchSampler` as: " | |||
f"`{dataloader_self_name}(dataset, sampler=DistributedBatchSampler(dataset))`." | |||
) | |||
# this error targets the case where the dataloader passed in is not a plain DataLoader but a customized one whose __init__ has no **kwargs; | |||
@@ -340,12 +355,28 @@ def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"): | |||
missing_kwargs = sorted(missing_kwargs) | |||
dataloader_self_name = dataloader.__class__.__name__ | |||
raise Exception( | |||
f"Trying to inject `DistributedBatchSampler` into the `{dataloader_self_name}` instance. " | |||
f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. " | |||
"This would fail as it doesn't expose all its attributes in the `__init__` signature. " | |||
f"The missing arguments are {missing_kwargs}. " | |||
f"HINT: If you wrote the `{dataloader_self_name}` class, add the `__init__` arguments or " | |||
"manually add the `DistributedBatchSampler` as: " | |||
f"`{dataloader_self_name}(dataset, sampler=DistributedBatchSampler(dataset))`." | |||
) | |||
return type(dataloader)(**reconstruct_args) | |||
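A hedged sketch of the kind of customized DataLoader this reconstruction supports (the `my_arg` parameter is made up for illustration):

from paddle.io import DataLoader

class MyLoader(DataLoader):
    def __init__(self, dataset, my_arg=0, **kwargs):
        # exposing the extra argument as an instance attribute (same name)
        # is what lets replace_batch_sampler reconstruct this loader;
        # a parameter that is neither an attribute nor covered by **kwargs
        # would trigger the "missing attributes" exception above
        self.my_arg = my_arg
        super().__init__(dataset, **kwargs)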
def replace_sampler(dataloader, new_sampler): | |||
""" | |||
Rebuild a BatchSampler around `new_sampler` and swap it into `dataloader` | |||
""" | |||
new_batch_sampler = deepcopy(dataloader.batch_sampler) | |||
new_batch_sampler.sampler = new_sampler | |||
return replace_batch_sampler(dataloader, new_batch_sampler) | |||
def optimizer_state_to_device(state, device): | |||
new_state = {} | |||
for name, param in state.items(): | |||
if isinstance(param, dict): | |||
new_state[name] = optimizer_state_to_device(param, device) | |||
elif isinstance(param, paddle.Tensor): | |||
new_state[name] = paddle_to(param, device).clone() | |||
else: | |||
new_state[name] = param | |||
return new_state |
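For context, a sketch of how the checkpointing code above is meant to use this helper: the nested optimizer state is moved to cpu before serialization so the checkpoint does not hold gpu memory (toy model, made-up file name):

import paddle

model = paddle.nn.Linear(4, 4)
optimizer = paddle.optimizer.Adam(parameters=model.parameters())

opt_state = optimizer.state_dict()
# tensors inside the (possibly nested) dict are cloned onto the cpu;
# non-tensor leaves such as ints or floats pass through unchanged
opt_state["state"] = optimizer_state_to_device(opt_state, "cpu")
paddle.save({"optimizer0": opt_state}, "checkpoint.pdopt")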
@@ -12,6 +12,7 @@ if _NEED_IMPORT_TORCH: | |||
import torch | |||
import torch.distributed as dist | |||
from torch.nn.parallel import DistributedDataParallel | |||
from torch.utils.data import BatchSampler | |||
__all__ = [ | |||
'TorchDDPDriver' | |||
@@ -167,6 +168,7 @@ class TorchDDPDriver(TorchDriver): | |||
不管是什么情况,`TorchDDPDriver` 在 `setup` 函数的最后,都会将所有进程的 pid 主动记录下来,这样当一个进程出现 exception 后, | |||
driver 的 on_exception 函数就会被 trainer 调用,其会调用 os.kill 指令将其它进程 kill 掉; | |||
""" | |||
# after all the additions, be careful where the super() call sits here; | |||
super(TorchDDPDriver, self).__init__(model, fp16=fp16, **kwargs) | |||
if isinstance(model, torch.nn.DataParallel): | |||
@@ -202,8 +204,8 @@ class TorchDDPDriver(TorchDriver): | |||
# 我们就直接将 model_device 置为 None; | |||
self.model_device = None | |||
def _running_fn_(batch, step_fn, signature_fn): | |||
if isinstance(batch, Dict): | |||
def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call): | |||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||
return auto_param_call(step_fn, batch, signature_fn=signature_fn) | |||
else: | |||
return step_fn(batch) | |||
@@ -214,7 +216,7 @@ class TorchDDPDriver(TorchDriver): | |||
"Notice your model is a `DistributedDataParallel` model. And your " | |||
"model also implements the `train_step` method, which we can not call actually, we will" | |||
" call `forward` function instead of `train_step` and you should note that.") | |||
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||
# self._train_signature_fn = model.forward | |||
if hasattr(model, "validate_step"): | |||
@@ -222,7 +224,7 @@ class TorchDDPDriver(TorchDriver): | |||
"Notice your model is a `DistributedDataParallel` model. And your " | |||
"model also implements the `validate_step` method, which we can not call actually, " | |||
"we will call `forward` function instead of `validate_step` and you should note that.") | |||
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||
# self._validate_signature_fn = model.forward | |||
if hasattr(model, "test_step"): | |||
@@ -230,14 +232,11 @@ class TorchDDPDriver(TorchDriver): | |||
"Notice your model is a `DistributedDataParallel` model. And your " | |||
"model also implements the `test_step` method, which we can not call actually, we will" | |||
" call `forward` function instead of `test_step` and you should note that.") | |||
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||
# self._test_signature_fn = model.forward | |||
# when the user initializes DDP themselves on the outside, we set model_device to None; the user can then move the data to a given device via `data_device`; | |||
self._data_device = kwargs.get("data_device", None) | |||
# if self.outside_ddp and self._data_device is None: | |||
# raise RuntimeError("When you initialize your ddp out of our control, the parameter " | |||
# "`data_device` can not be None.") | |||
if isinstance(self._data_device, int): | |||
if self._data_device < 0: | |||
raise ValueError("Parameter `data_device` can not be smaller than 0.") | |||
@@ -349,9 +348,9 @@ class TorchDDPDriver(TorchDriver): | |||
**self._ddp_kwargs | |||
) | |||
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}) | |||
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}) | |||
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}) | |||
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call) | |||
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call) | |||
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call) | |||
self._configured = True | |||
@@ -472,12 +471,11 @@ class TorchDDPDriver(TorchDriver): | |||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " | |||
"control.") | |||
else: | |||
if isinstance(dist, ReproducibleBatchSampler): | |||
dist = re_instantiate_sampler(dist) | |||
return replace_batch_sampler(dataloader, dist) | |||
if isinstance(dist, ReproducibleSampler): | |||
dist = re_instantiate_sampler(dist) | |||
return replace_sampler(dataloader, dist) | |||
args = self.get_dataloader_args(dataloader) | |||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||
return replace_batch_sampler(dataloader, re_instantiate_sampler(args.batch_sampler)) | |||
if isinstance(args.sampler, ReproducibleSampler): | |||
return replace_sampler(dataloader, re_instantiate_sampler(args.sampler)) | |||
return dataloader | |||
# trainer | |||
elif dist == "dist": | |||
@@ -526,18 +524,11 @@ class TorchDDPDriver(TorchDriver): | |||
num_replicas=self.world_size, | |||
rank=self.global_rank | |||
) | |||
return replace_sampler(dataloader, sampler) | |||
batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False) | |||
return replace_batch_sampler(dataloader, batch_sampler) | |||
else: | |||
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | |||
def backward(self, loss): | |||
self.grad_scaler.scale(loss).backward() | |||
def step(self): | |||
for optimizer in self.optimizers: | |||
self.grad_scaler.step(optimizer) | |||
self.grad_scaler.update() | |||
def is_global_zero(self): | |||
return self.global_rank == 0 | |||
@@ -3,28 +3,20 @@ import pickle | |||
_pickler = pickle.Pickler | |||
_unpickler = pickle.Unpickler | |||
from typing import Any, List | |||
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 | |||
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 | |||
from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
from torch import distributed as dist | |||
try: | |||
from torch._C._distributed_c10d import ProcessGroupMPI | |||
except ImportError: | |||
_MPI_AVAILABLE = False | |||
try: | |||
from torch._C._distributed_c10d import ProcessGroupNCCL | |||
except ImportError: | |||
_NCCL_AVAILABLE = False | |||
try: | |||
from torch._C._distributed_c10d import ProcessGroupGloo | |||
from torch._C._distributed_c10d import _ProcessGroupWrapper | |||
except ImportError: | |||
_GLOO_AVAILABLE = False | |||
if _TORCH_GREATER_EQUAL_1_8: | |||
try: | |||
from torch._C._distributed_c10d import ProcessGroupGloo | |||
from torch._C._distributed_c10d import _ProcessGroupWrapper | |||
except ImportError: | |||
pass | |||
from fastNLP.core.utils import apply_to_collection | |||
@@ -42,7 +34,7 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): | |||
) | |||
def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None): | |||
def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP): | |||
""" | |||
Gather objects from the other ranks onto the dst rank. | |||
@@ -91,6 +83,9 @@ def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None) | |||
>>> output | |||
['foo', 12, {1: 2}] | |||
""" | |||
if group is None: | |||
group = DEFAULT_TORCH_GROUP | |||
if dist.distributed_c10d._rank_not_in_group(group): | |||
return | |||
@@ -193,7 +188,7 @@ def _to_device(tensor, device): | |||
return tensor.contiguous().to(device) | |||
def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: | |||
def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) ->List: | |||
""" | |||
Through this single interface, data of any type can take part in an all_gather. Non-tensor data is transferred by pickling and unpickling it. | |||
@@ -217,7 +212,8 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: | |||
:param group: | |||
:return: the result is [obj0, obj1, ...], where obj_i is the obj on rank i. | |||
""" | |||
# # first move everything to cpu and make it contiguous, so that pickle does not run into trouble | |||
if group is None: | |||
group = DEFAULT_TORCH_GROUP | |||
if isinstance(obj, torch.Tensor): | |||
objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))] | |||
dist.all_gather(objs, obj, group=group) | |||
@@ -232,7 +228,7 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: | |||
return objs | |||
def fastnlp_torch_broadcast_object(obj, src, device=None, group=None): | |||
def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GROUP): | |||
""" | |||
Broadcast the obj object on src to the other ranks. | |||
@@ -242,6 +238,8 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=None): | |||
:param group: | |||
:return: | |||
""" | |||
if group is None: | |||
group = DEFAULT_TORCH_GROUP | |||
cur_rank = dist.get_rank(group) | |||
if cur_rank == src: | |||
# move any tensors to cpu first to ease pickling; otherwise unpickling may end up on the card the data was sent from | |||
@@ -335,19 +333,21 @@ def all_gather_object(object_list, obj, group=None): | |||
>>> output | |||
['foo', 12, {1: 2}] | |||
""" | |||
if dist._rank_not_in_group(group): | |||
if dist.distributed_c10d._rank_not_in_group(group): | |||
return | |||
if _TORCH_GREATER_EQUAL_1_8: | |||
current_device = torch.device("cpu") | |||
is_nccl_backend = _check_for_nccl_backend(group) | |||
if is_nccl_backend: | |||
# See note about using torch.cuda.current_device() here in docstring. | |||
# We cannot simply use my_rank since rank == device is not necessarily | |||
# true. | |||
current_device = torch.device("cuda", torch.cuda.current_device()) | |||
else: | |||
current_device = torch.cuda.current_device() | |||
input_tensor, local_size = _object_to_tensor(obj, device=current_device) | |||
input_tensor, local_size = _object_to_tensor(obj) | |||
current_device = torch.device("cpu") | |||
is_nccl_backend = _check_for_nccl_backend(group) | |||
if is_nccl_backend: | |||
# See note about using torch.cuda.current_device() here in docstring. | |||
# We cannot simply use my_rank since rank == device is not necessarily | |||
# true. | |||
current_device = torch.device("cuda", torch.cuda.current_device()) | |||
input_tensor = input_tensor.to(current_device) | |||
local_size = local_size.to(current_device) | |||
# Gather all local sizes. This is so that we can find the max size, and index | |||
# until the correct size when deserializing the tensors. | |||
group_size = dist.get_world_size(group=group) | |||
@@ -378,3 +378,4 @@ def all_gather_object(object_list, obj, group=None): | |||
tensor = tensor.cpu() | |||
tensor_size = object_size_list[i] | |||
object_list[i] = _tensor_to_object(tensor, tensor_size) | |||
return object_list |
@@ -13,7 +13,7 @@ __all__ = [ | |||
from .torch_driver import TorchDriver | |||
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | |||
from fastNLP.core.utils import auto_param_call | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler | |||
from fastNLP.core.log import logger | |||
@@ -102,29 +102,21 @@ class TorchSingleDriver(TorchDriver): | |||
def train_step(self, batch) -> Dict: | |||
# if batch is a Dict, we do parameter matching for it by default; otherwise it is passed straight to `train_step` for the user to handle; | |||
if isinstance(batch, Dict): | |||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | |||
else: | |||
return self._train_step(batch) | |||
def backward(self, loss): | |||
self.grad_scaler.scale(loss).backward() | |||
def step(self): | |||
for optimizer in self.optimizers: | |||
self.grad_scaler.step(optimizer) | |||
self.grad_scaler.update() | |||
def validate_step(self, batch) -> Dict: | |||
# our Tester's logic is to hand all metrics to the tester, which then drives each metric's update and compute; so whether or not | |||
# the user implements validate_step, it should return a dict; which entries get used is decided by each metric inside validate_batch_loop; | |||
if isinstance(batch, Dict): | |||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | |||
else: | |||
return self._validate_step(batch) | |||
def test_step(self, batch) -> Dict: | |||
if isinstance(batch, Dict): | |||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | |||
else: | |||
return self._test_step(batch) | |||
@@ -148,7 +140,7 @@ class TorchSingleDriver(TorchDriver): | |||
return replace_sampler(dataloader, sampler) | |||
if reproducible: | |||
batch_sampler = ReproducibleBatchSampler( | |||
batch_sampler = RandomBatchSampler( | |||
batch_sampler=args.batch_sampler, | |||
batch_size=args.batch_size, | |||
drop_last=args.drop_last | |||
@@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | |||
from fastNLP.envs import rank_zero_call | |||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler | |||
class TorchDriver(Driver): | |||
@@ -51,6 +51,9 @@ class TorchDriver(Driver): | |||
# sets the `non_blocking` parameter of `torch_move_data_to_device`; | |||
self.non_blocking = kwargs.get("torch_non_blocking", True) | |||
# whether to disable the automatic parameter matching in auto_param_call; | |||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | |||
def zero_grad(self, set_to_none: bool = False): | |||
for optimizer in self.optimizers: | |||
self._clear_grad(optimizer, set_to_none) | |||
@@ -69,6 +72,14 @@ class TorchDriver(Driver): | |||
p.grad.requires_grad_(False) | |||
p.grad.zero_() | |||
def backward(self, loss): | |||
self.grad_scaler.scale(loss).backward() | |||
def step(self): | |||
for optimizer in self.optimizers: | |||
self.grad_scaler.step(optimizer) | |||
self.grad_scaler.update() | |||
@staticmethod | |||
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||
if is_train: | |||
@@ -102,7 +113,7 @@ class TorchDriver(Driver): | |||
if mode == "validate": | |||
if not hasattr(model, "validate_step"): | |||
if hasattr(model, "test_step"): | |||
logger.warning( | |||
logger.warning_once( | |||
"Your model does not have 'validate_step' method but has 'test_step' method, but you" | |||
"are using 'mode=validate', we are going to use 'test_step' to substitute for" | |||
"'validate_step'.") | |||
@@ -191,9 +202,20 @@ class TorchDriver(Driver): | |||
sampler = dataloader_args.sampler | |||
else: | |||
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | |||
num_consumed_batches = states.pop('num_consumed_batches') | |||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | |||
states['sampler_states'] = sampler.state_dict() | |||
sampler_states = sampler.state_dict() | |||
# if present, num_consumed_samples needs special handling: because the DataLoader prefetches, using the sampler's | |||
# num_consumed_samples directly would overcount what has actually been consumed. | |||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||
if num_consumed_samples_array is not None: | |||
if isinstance(sampler, ReproducibleSampler):  # for a plain sampler we have to account for batch_size. | |||
try: | |||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||
except TypeError:  # batch_size may be None, in which case we can only lose precision | |||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||
else: | |||
raise RuntimeError( | |||
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | |||
@@ -252,7 +274,7 @@ class TorchDriver(Driver): | |||
elif self.is_distributed(): | |||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") | |||
else: | |||
sampler = ReproducibleBatchSampler( | |||
sampler = RandomBatchSampler( | |||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | |||
batch_size=dataloader_args.batch_size, | |||
drop_last=dataloader_args.drop_last | |||
@@ -8,6 +8,7 @@ import numpy as np | |||
import inspect | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
from fastNLP.core.samplers import re_instantiate_sampler | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
@@ -140,24 +141,25 @@ class _DDPWrappingModel(Module): | |||
pytorch lightning first performs an unwrapping_model step here, but that does not seem necessary for us; this note is left in place and can be revisited if the need arises; | |||
""" | |||
_forward_state = kwargs.pop(_MODE_PARAMETER) | |||
forward_state = kwargs.pop(_MODE_PARAMETER) | |||
wo_auto_param_call = kwargs.pop("wo_auto_param_call") | |||
if _forward_state == ForwardState.TRAIN: | |||
if isinstance(batch, Dict): | |||
if forward_state == ForwardState.TRAIN: | |||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | |||
else: | |||
return self._train_step(batch) | |||
elif _forward_state == ForwardState.VALIDATE: | |||
if isinstance(batch, Dict): | |||
elif forward_state == ForwardState.VALIDATE: | |||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | |||
else: | |||
return self._validate_step(batch) | |||
elif _forward_state == ForwardState.TEST: | |||
if isinstance(batch, Dict): | |||
elif forward_state == ForwardState.TEST: | |||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | |||
else: | |||
return self._test_step(batch) | |||
elif _forward_state == ForwardState.PREDICT: | |||
elif forward_state == ForwardState.PREDICT: | |||
raise NotImplementedError("'PREDICT' mode has not been implemented.") | |||
else: | |||
raise NotImplementedError("You should direct a concrete mode.") | |||
@@ -294,7 +296,6 @@ def replace_sampler(dataloader: "DataLoader", sampler): | |||
"manually add the `DistributedSampler` as: " | |||
f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`." | |||
) | |||
return type(dataloader)(**reconstruct_args) | |||
@@ -306,12 +307,8 @@ def _dataloader_init_kwargs_resolve_sampler( | |||
""" | |||
batch_sampler = getattr(dataloader, "batch_sampler") | |||
# checking the batch sampler type is different than PyTorch default. | |||
if batch_sampler is not None and type(batch_sampler) is not BatchSampler: | |||
batch_sampler = type(batch_sampler)( | |||
sampler, | |||
batch_size=batch_sampler.batch_size, | |||
drop_last=batch_sampler.drop_last, | |||
) | |||
if batch_sampler is not None and not isinstance(batch_sampler, BatchSampler): | |||
batch_sampler = re_instantiate_sampler(batch_sampler) | |||
return { | |||
"sampler": None, | |||
@@ -342,6 +339,9 @@ def replace_batch_sampler(dataloader, new_batch_sampler): | |||
params = {k: getattr(dataloader, k) for k in params_keys} | |||
params["batch_sampler"] = new_batch_sampler | |||
return type(dataloader)(**params) | |||
# TODO could this be routed through auto_param_call? | |||
# return auto_param_call(type(dataloader), params, {'self': type(dataloader).__new__()}, | |||
# signature_fn=type(dataloader).__init__) | |||
def optimizer_state_to_device(state, device): | |||
@@ -51,6 +51,7 @@ class LoggerSingleton(type): | |||
class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||
def __init__(self, name): | |||
super().__init__(name) | |||
self._warning_msgs = set() | |||
def add_file(self, path: Optional[Union[str, Path]] = None, level='AUTO', remove_other_handlers: bool = False, | |||
mode: str = "w"): | |||
@@ -108,10 +109,25 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||
kwargs = self._add_rank_info(kwargs) | |||
self._log(WARNING, msg, args, **kwargs) | |||
def warning_once(self, msg, *args, **kwargs): | |||
""" | |||
Identical warning messages are only emitted once. | |||
:param msg: | |||
:param args: | |||
:param kwargs: | |||
:return: | |||
""" | |||
if msg not in self._warning_msgs: | |||
if self.isEnabledFor(WARNING): | |||
kwargs = self._add_rank_info(kwargs) | |||
self._log(WARNING, msg, args, **kwargs) | |||
self._warning_msgs.add(msg) | |||
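A minimal sketch of the deduplication (assuming the singleton logger exported by fastNLP.core.log): | |||
from fastNLP.core.log import logger | |||
for _ in range(3): | |||
    logger.warning_once("dataset contains empty instances")  # emitted only on the first pass | |||
    logger.warning("dataset contains empty instances")       # emitted on every pass | |||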
def warn(self, msg, *args, **kwargs): | |||
warnings.warn("The 'warn' method is deprecated, " | |||
"use 'warning' instead", DeprecationWarning, 2) | |||
self.warning(msg, *args, **kwargs) | |||
if self.isEnabledFor(WARNING): | |||
kwargs = self._add_rank_info(kwargs) | |||
self._log(WARNING, msg, args, **kwargs) | |||
def error(self, msg, *args, **kwargs): | |||
""" | |||
@@ -14,8 +14,7 @@ from fastNLP.core.utils.utils import seq_len_to_mask | |||
class Accuracy(Metric): | |||
def __init__(self, backend: Union[str, Backend, None] = 'auto', | |||
aggregate_when_get_metric: bool = True): | |||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True): | |||
super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | |||
self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend) | |||
self.register_element(name='total', value=0, aggregate_method="sum", backend=backend) | |||
@@ -64,7 +63,7 @@ class Accuracy(Metric): | |||
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | |||
else: | |||
raise RuntimeError(f"when pred havesize:{pred.shape}, target should have size: {pred.shape} or " | |||
raise RuntimeError(f"when pred have size:{pred.shape}, target should have size: {pred.shape} or " | |||
f"{pred.shape[:-1]}, got {target.shape}.") | |||
if masks is not None: | |||
@@ -23,14 +23,14 @@ __all__ = [ | |||
"BucketedBatchSampler", | |||
"ReproducibleBatchSampler", | |||
"re_instantiate_sampler", | |||
"conversion_between_reproducible_and_unrepeated_sampler" | |||
"re_instantiate_sampler" | |||
] | |||
from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler | |||
from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler | |||
from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler | |||
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler | |||
from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler | |||
from .utils import re_instantiate_sampler | |||
from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler | |||
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler | |||
@@ -0,0 +1,33 @@ | |||
from fastNLP.core.samplers import re_instantiate_sampler | |||
from fastNLP.core.samplers.reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, \ | |||
SortedSampler | |||
from fastNLP.core.samplers.unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, \ | |||
UnrepeatedSequentialSampler, UnrepeatedSortedSampler | |||
def conversion_between_reproducible_and_unrepeated_sampler(sampler): | |||
""" | |||
Convert a sampler into its reproducible or unrepeated counterpart. If the input is an UnrepeatedSampler (or a | |||
ReproducibleSampler) with no matching counterpart, a TypeError is raised. | |||
:param sampler: | |||
:return: | |||
""" | |||
assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \ | |||
"The sampler must be UnrepeatedSampler or ReproducibleSampler" | |||
if isinstance(sampler, UnrepeatedSampler): | |||
if isinstance(sampler, UnrepeatedRandomSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler) | |||
elif isinstance(sampler, UnrepeatedSequentialSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler) | |||
elif isinstance(sampler, UnrepeatedSortedSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler) | |||
raise TypeError(f"{sampler.__class__} has no unrepeated version.") | |||
else: | |||
if isinstance(sampler, RandomSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler) | |||
elif isinstance(sampler, SequentialSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler) | |||
elif isinstance(sampler, SortedSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler) | |||
raise TypeError(f"{sampler.__class__} has no reproducible version.") |
@@ -4,16 +4,18 @@ __all__ = [ | |||
] | |||
import math | |||
from array import array | |||
from copy import deepcopy | |||
from typing import Dict, Union, List | |||
from itertools import chain | |||
import os | |||
import numpy as np | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.log import logger | |||
from .utils import create_array, NumConsumedSamplesArray | |||
from abc import abstractmethod | |||
from fastNLP.envs.env import FASTNLP_DEQUE_SIZE | |||
class ReproducibleBatchSampler: | |||
@@ -34,6 +36,13 @@ class ReproducibleBatchSampler: | |||
@abstractmethod | |||
def state_dict(self): | |||
""" | |||
Because today's DataLoaders all prefetch data, please refer to how num_consumed_samples_array is handled inside the | |||
states of RandomBatchSampler in order to set this value correctly. The idea is to record the num_consumed_samples | |||
corresponding to each index; at Trainer.save time, the entry matching the number of samples the Trainer actually | |||
forwarded is looked up in num_consumed_samples_array and stored. | |||
:return: | |||
""" | |||
raise NotImplementedError("Each specific batch_sampler should implement its own `state_dict` method.") | |||
@abstractmethod | |||
@@ -67,7 +76,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
self.batch_size = batch_size | |||
self.drop_last = drop_last | |||
self.data_idx = kwargs.get("data_idx", 0) | |||
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) | |||
self.index_list = kwargs.get("index_list", self._iterate_sampler()) | |||
self.need_reinitialize = kwargs.get("need_reinitialize", False) | |||
@@ -80,36 +89,40 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
# this means a sampler was passed at construction time, which in theory corresponds to a dataloader built with neither batch_size nor batch_sampler; | |||
else: | |||
_index_lst.append(idx) | |||
# an unsigned int is 4 bytes on 64-bit machines, so the largest representable value is 4294967295; | |||
if len(_index_lst) > 4294967295: | |||
# note that self.index_list stores the indices of the entire dataset; | |||
# unsigned long | |||
_index_lst = array("L", _index_lst) | |||
else: | |||
# unsigned int | |||
_index_lst = array("I", _index_lst) | |||
_index_lst = create_array(len(_index_lst), _index_lst) | |||
return _index_lst | |||
def __iter__(self): | |||
if self.need_reinitialize: | |||
self.index_list = self._iterate_sampler() | |||
self.data_idx = 0 | |||
self.num_consumed_samples = 0 | |||
else: | |||
self.need_reinitialize = True | |||
batch = [] | |||
if self.data_idx: | |||
index_list = self.index_list[self.data_idx:] | |||
if self.num_consumed_samples: | |||
index_list = self.index_list[self.num_consumed_samples:] | |||
else: | |||
index_list = self.index_list | |||
# remember the consumed_samples of every batch; this is needed because today's dataloaders all prefetch data, so only | |||
# together with batch_idx_in_epoch from the Trainer can the data actually consumed be determined. This variable records | |||
# the true num_consumed_samples value at each yield. | |||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30), | |||
num_consumed_samples=self.num_consumed_samples) | |||
for idx in index_list: | |||
batch.append(idx) | |||
self.data_idx += 1 | |||
if len(batch) == self.batch_size: | |||
self.num_consumed_samples += self.batch_size # [16, 32, 48, 64,..., ] | |||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
yield batch | |||
batch = [] | |||
if len(batch) > 0 and not self.drop_last: | |||
self.num_consumed_samples += len(batch) | |||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
yield batch | |||
# reset to avoid boundary-condition issues | |||
self.num_consumed_samples = 0 | |||
delattr(self, 'num_consumed_samples_array') | |||
def __len__(self) -> int: | |||
if self.drop_last: | |||
@@ -118,7 +131,13 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
return (len(self.index_list) + self.batch_size - 1) // self.batch_size | |||
def state_dict(self) -> Dict: | |||
return {"index_list": deepcopy(self.index_list), "data_idx": self.data_idx, 'sampler_type': self.__class__.__name__} | |||
states = { | |||
"index_list": deepcopy(self.index_list), | |||
"num_consumed_samples": self.num_consumed_samples, | |||
'sampler_type': self.__class__.__name__ | |||
} | |||
states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None) | |||
return states | |||
def load_state_dict(self, states: Dict): | |||
assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \ | |||
@@ -128,11 +147,11 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
assert len(_index_list) == len(self.index_list), "The number of samples is different between the checkpoint " \ | |||
"record and current dataset." | |||
self.index_list = _index_list | |||
self.data_idx = states["data_idx"] | |||
self.num_consumed_samples = states["num_consumed_samples"] | |||
self.need_reinitialize = False | |||
def set_distributed(self, num_replicas, rank, pad=True): | |||
raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.") | |||
raise RuntimeError(f"RandomBatchSampler does not support to change to distributed training.") | |||
def set_epoch(self, epoch): | |||
if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch): | |||
@@ -141,10 +160,10 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
@property | |||
def batch_idx_in_epoch(self): | |||
if self.drop_last: | |||
return len(self.index_list) // self.batch_size - (len(self.index_list) - self.data_idx) // self.batch_size | |||
return len(self.index_list) // self.batch_size - (len(self.index_list) - self.num_consumed_samples) // self.batch_size | |||
else: | |||
return (len(self.index_list) + self.batch_size - 1) // self.batch_size - \ | |||
(len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size | |||
(len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size | |||
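For example, with len(index_list) = 10, batch_size = 4, drop_last = False and num_consumed_samples = 8, the property | |||
above yields ceil(10/4) - ceil(2/4) = 3 - 1 = 2: two full batches have already been produced in this epoch. | |||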
class BucketedBatchSampler(ReproducibleBatchSampler): | |||
@@ -166,8 +185,8 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
:param kwargs: reserved for internal use by fastNLP | |||
""" | |||
super().__init__() | |||
if isinstance(dataset, DataSet): | |||
length = dataset.get_field(length) | |||
if isinstance(dataset, DataSet) and isinstance(length, str): | |||
length = dataset.get_field(length).content | |||
if not isinstance(length[0], int): | |||
length = list(map(len, length)) | |||
else: | |||
@@ -180,7 +199,6 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
self.length = np.array(length, dtype=int) # the length of each sample. | |||
self.sorted_indices = np.argsort(self.length)[::-1] # sample indices sorted by length, longest first | |||
self.batch_size = batch_size | |||
self.num_batch_per_bucket = num_batch_per_bucket | |||
self.shuffle = shuffle | |||
@@ -212,13 +230,13 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
self.rank = rank | |||
self.pad = pad | |||
num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \ | |||
else len(self.dataset) | |||
if self.drop_last: | |||
assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \ | |||
"than the number of replicates multiplied " \ | |||
"with batch_size when drop_last=True." | |||
# num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \ | |||
# else len(self.dataset) | |||
# | |||
# if self.drop_last: | |||
# assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \ | |||
# "than the number of replicates multiplied " \ | |||
# "with batch_size when drop_last=True." | |||
return self | |||
@@ -243,7 +261,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \ | |||
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | |||
def __len__(self): | |||
def __len__(self) -> int: | |||
""" | |||
Returns how many more batches of data this sampler will yield | |||
@@ -309,11 +327,15 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: | |||
batches = batches[:-1] | |||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30), | |||
num_consumed_samples=self.num_consumed_samples) | |||
for batch in batches: | |||
self.num_consumed_samples += self.num_replicas * len(batch) | |||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
yield list(map(int, batch)) | |||
self.during_iter = False | |||
self.num_consumed_samples = 0 | |||
delattr(self, 'num_consumed_samples_array') | |||
self.old_batch_size = self.batch_size | |||
self.old_num_batch_per_bucket = self.num_batch_per_bucket | |||
self.old_num_replicas = self.num_replicas | |||
@@ -356,7 +378,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
batch_indices = list(batch_indices[:-1]) | |||
rng = np.random.default_rng(abs(seed)) # a fresh rng so that differing bucket lengths cannot perturb the random state | |||
rng.shuffle(batch_indices) # shuffle across batches as well; this keeps the batch lengths on every card close to each other. | |||
batches = (np.array(batches)[batch_indices]).tolist() | |||
batches = (np.array(batches, dtype=object)[batch_indices]).tolist() | |||
if last_batches: | |||
batches = batches + last_batches | |||
return batches | |||
@@ -365,21 +387,16 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
if self.old_batch_size != self.batch_size or self.old_num_batch_per_bucket != self.num_batch_per_bucket: | |||
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | |||
" consumed. ") | |||
states = { | |||
'seed': self.seed, | |||
'epoch': self.epoch, | |||
'num_consumed_samples': self.num_consumed_samples, # note: this value counts the data trained on across all ranks; | |||
'sampler_type': self.__class__.__name__, | |||
'length': len(self.dataset), | |||
'shuffle': self.shuffle, | |||
'batch_size': self.batch_size, | |||
'num_batch_per_bucket': self.num_batch_per_bucket, | |||
'num_replicas': self.num_replicas | |||
} | |||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | |||
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | |||
'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket, | |||
'num_replicas': self.num_replicas, | |||
'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} | |||
return states | |||
def load_state_dict(self, states: Dict): | |||
# if self.during_iter is True, data_idx must be 0; | |||
# if self.during_iter is True, num_consumed_samples must be 0; | |||
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | |||
"during an unfinished iteration." | |||
@@ -1,16 +1,21 @@ | |||
__all__ = [ | |||
'ReproducibleSampler', | |||
'RandomSampler', | |||
"SortedSampler", | |||
"SequentialSampler" | |||
] | |||
from typing import Dict, List, Union | |||
import math | |||
import os | |||
import numpy as np | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.envs.env import FASTNLP_DEQUE_SIZE | |||
from .utils import NumConsumedSamplesArray | |||
__all__ = [ | |||
'ReproducibleSampler', | |||
'RandomSampler', | |||
"SortedSampler", | |||
"SequentialSampler" | |||
] | |||
class ReproducibleSampler: | |||
@@ -30,6 +35,13 @@ class ReproducibleSampler: | |||
raise NotImplementedError("Each specific sampler should implement its own `__iter__` method.") | |||
def state_dict(self): | |||
""" | |||
Because today's DataLoaders all prefetch data, please refer to how num_consumed_samples_array is handled inside the | |||
states of RandomSampler in order to set this value correctly. The idea is to record the num_consumed_samples | |||
corresponding to each index; at Trainer.save time, the entry matching the number of samples the Trainer actually | |||
forwarded is looked up in num_consumed_samples_array and stored. | |||
:return: | |||
""" | |||
raise NotImplementedError("Each specific sampler should implement its own `state_dict` method.") | |||
def load_state_dict(self, states): | |||
@@ -109,12 +121,15 @@ class RandomSampler(ReproducibleSampler): | |||
indices = indices[self.num_consumed_samples:] | |||
indices = indices[self.rank:len(indices):self.num_replicas] | |||
assert len(indices) == self.num_left_samples | |||
for index in indices: | |||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||
num_consumed_samples=self.num_consumed_samples) | |||
for idx, index in enumerate(indices, start=1): | |||
self.num_consumed_samples += self.num_replicas | |||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
yield index | |||
self.during_iter = False | |||
self.num_consumed_samples = 0 | |||
delattr(self, 'num_consumed_samples_array') | |||
def generate_indices(self) -> List[int]: | |||
""" | |||
@@ -134,18 +149,13 @@ class RandomSampler(ReproducibleSampler): | |||
return indices | |||
def state_dict(self) -> Dict: | |||
states = { | |||
'seed': self.seed, | |||
'epoch': self.epoch, | |||
'num_consumed_samples': self.num_consumed_samples, # note: this value counts the data trained on across all ranks; | |||
'sampler_type': self.__class__.__name__, | |||
'length': len(self.dataset), | |||
'shuffle': self.shuffle | |||
} | |||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | |||
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | |||
'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} | |||
return states | |||
def load_state_dict(self, states: Dict): | |||
# if self.during_iter is True, data_idx must be 0; | |||
# if self.during_iter is True, num_consumed_samples must be 0; | |||
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | |||
"during an unfinished iteration." | |||
@@ -158,7 +168,7 @@ class RandomSampler(ReproducibleSampler): | |||
self.seed = states['seed'] | |||
self.epoch = states['epoch'] | |||
self.num_consumed_samples = states['num_consumed_samples'] | |||
if self.num_consumed_samples>=length: # if the last sample had already been reached at save time, simply reset to 0 | |||
if self.num_consumed_samples >= length: # if the last sample had already been reached at save time, simply reset to 0 | |||
self.num_consumed_samples = 0 | |||
if self.shuffle != states['shuffle']: | |||
logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, " | |||
@@ -245,11 +255,15 @@ class SequentialSampler(RandomSampler): | |||
indices = indices[self.rank:len(indices):self.num_replicas] | |||
assert len(indices) == self.num_left_samples | |||
for index in indices: | |||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||
num_consumed_samples=self.num_consumed_samples) | |||
for idx, index in enumerate(indices, start=1): | |||
self.num_consumed_samples += self.num_replicas | |||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
yield index | |||
self.during_iter = False | |||
self.num_consumed_samples = 0 | |||
delattr(self, 'num_consumed_samples_array') | |||
def generate_indices(self) -> List[int]: | |||
""" | |||
@@ -260,15 +274,13 @@ class SequentialSampler(RandomSampler): | |||
return list(range(len(self.dataset))) | |||
def state_dict(self) -> Dict: | |||
states = { | |||
'num_consumed_samples': self.num_consumed_samples, # note: this value counts the data trained on across all ranks; | |||
'sampler_type': self.__class__.__name__, | |||
'length': len(self.dataset), | |||
} | |||
states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, | |||
'length': len(self.dataset), | |||
'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} | |||
return states | |||
def load_state_dict(self, states: Dict): | |||
# if self.during_iter is True, data_idx must be 0; | |||
# if self.during_iter is True, num_consumed_samples must be 0; | |||
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | |||
"during an unfinished iteration." | |||
@@ -295,8 +307,8 @@ class SortedSampler(SequentialSampler): | |||
:param kwargs: reserved for internal use by fastNLP | |||
""" | |||
super().__init__(dataset=dataset, **kwargs) | |||
if isinstance(dataset, DataSet): | |||
length = dataset.get_field(length) | |||
if isinstance(dataset, DataSet) and isinstance(length, str): | |||
length = dataset.get_field(length).content | |||
if not isinstance(length[0], int): | |||
length = list(map(len, length)) | |||
else: | |||
@@ -334,9 +346,13 @@ class SortedSampler(SequentialSampler): | |||
indices = indices[self.rank:len(indices):self.num_replicas] | |||
assert len(indices) == self.num_left_samples | |||
for index in indices: | |||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||
num_consumed_samples=self.num_consumed_samples) | |||
for idx, index in enumerate(indices, start=1): | |||
self.num_consumed_samples += self.num_replicas | |||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
yield index | |||
self.during_iter = False | |||
self.num_consumed_samples = 0 | |||
delattr(self, 'num_consumed_samples_array') | |||
@@ -105,8 +105,8 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | |||
:param kwargs: reserved for internal use by fastNLP | |||
""" | |||
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) | |||
if isinstance(dataset, DataSet): | |||
length = dataset.get_field(length) | |||
if isinstance(dataset, DataSet) and isinstance(length, str): | |||
length = dataset.get_field(length).content | |||
if not isinstance(length[0], int): | |||
length = list(map(len, length)) | |||
else: | |||
@@ -1,42 +1,65 @@ | |||
__all__ = [ | |||
're_instantiate_sampler', | |||
'conversion_between_reproducible_and_unrepeated_sampler' | |||
're_instantiate_sampler' | |||
] | |||
from array import array | |||
from typing import Sequence | |||
from collections import deque | |||
from fastNLP.core.samplers.unrepeated_sampler import * | |||
from fastNLP.core.samplers.reproducible_sampler import * | |||
def re_instantiate_sampler(sampler, new_sampler_class=None): | |||
all_attributes = vars(sampler) | |||
if new_sampler_class is not None: | |||
return new_sampler_class(**all_attributes) | |||
return type(sampler)(**all_attributes) | |||
def conversion_between_reproducible_and_unrepeated_sampler(sampler): | |||
def create_array(length, fill_value) -> array: | |||
""" | |||
Convert a sampler into its reproducible or unrepeated counterpart. If the input is an UnrepeatedSampler (or a | |||
ReproducibleSampler) with no matching counterpart, a TypeError is raised. | |||
Automatically create an array for the given length; above 4294967295 typecode 'L' is required, otherwise 'I' is used | |||
:param sampler: | |||
:param length: | |||
:param fill_value: | |||
:return: | |||
""" | |||
assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \ | |||
"The sampler must be UnrepeatedSampler or ReproducibleSampler" | |||
if isinstance(sampler, UnrepeatedSampler): | |||
if isinstance(sampler, UnrepeatedRandomSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler) | |||
elif isinstance(sampler, UnrepeatedSequentialSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler) | |||
elif isinstance(sampler, UnrepeatedSortedSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler) | |||
raise TypeError(f"{sampler.__class__} has no unrepeated version.") | |||
if not isinstance(fill_value, Sequence): | |||
fill_value = [fill_value]*length | |||
if length > 4294967295: | |||
_index_lst = array("L", fill_value) | |||
else: | |||
if isinstance(sampler, RandomSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler) | |||
elif isinstance(sampler, SequentialSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler) | |||
elif isinstance(sampler, SortedSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler) | |||
raise TypeError(f"{sampler.__class__} has no reproducible version.") | |||
_index_lst = array("I", fill_value) | |||
return _index_lst | |||
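For instance (a small sketch of the typecode selection): | |||
idx = create_array(3, [7, 8, 9]) | |||
print(idx)  # array('I', [7, 8, 9]); a length above 4294967295 would select typecode 'L' instead | |||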
def re_instantiate_sampler(sampler, new_sampler_class=None): | |||
all_attributes = vars(sampler) | |||
if new_sampler_class is not None: | |||
return new_sampler_class(**all_attributes) | |||
return type(sampler)(**all_attributes) | |||
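Note that re_instantiate_sampler re-invokes a constructor with the instance's attributes, so the target class must accept each attribute as a keyword argument; a self-contained sketch: | |||
class ToySampler: | |||
    def __init__(self, dataset, shuffle=True, **kwargs): | |||
        self.dataset = dataset | |||
        self.shuffle = shuffle | |||
s1 = ToySampler(dataset=[1, 2, 3], shuffle=False) | |||
s2 = re_instantiate_sampler(s1)  # a fresh ToySampler carrying the same attribute values | |||
assert s2.shuffle is False and s2.dataset == [1, 2, 3] | |||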
class NumConsumedSamplesArray: | |||
def __init__(self, buffer_size=2000, num_consumed_samples=0): | |||
""" | |||
Keeps the most recent buffer_size values of num_consumed_samples; indexing returns the num_consumed_samples recorded at a given position | |||
ex: | |||
array = NumConsumedSamplesArray(buffer_size=3) | |||
for i in range(10): | |||
array.push(i) | |||
array[9] # outputs 9, the true num_consumed_samples at this position. | |||
array[6] # raises, because only the 3 most recent entries are kept, i.e. [7, 8, 9], and 6 is beyond that buffer | |||
:param buffer_size: how many history entries to keep. | |||
:param num_consumed_samples: the value of the first num_consumed_samples. | |||
""" | |||
self.count = 0 | |||
self.deque = deque(maxlen=buffer_size) | |||
if num_consumed_samples is not None: | |||
self.push(num_consumed_samples) | |||
self.buffer_size = buffer_size | |||
def __getitem__(self, item): | |||
if len(self.deque) == 0: # nothing cached means nothing has been written yet, so simply return 0 | |||
return 0 | |||
assert isinstance(item, int), "Only int index allowed." | |||
assert self.count-len(self.deque)<=item<self.count, f"Only keep {len(self.deque)} history index." | |||
index = len(self.deque) - (self.count - item) | |||
return self.deque[index] | |||
def push(self, num_consumed_samples): | |||
self.deque.append(num_consumed_samples) | |||
self.count += 1 |
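A sketch of how the drivers above consult this buffer at checkpoint time (batch size and step counts are illustrative): | |||
array = NumConsumedSamplesArray(buffer_size=30, num_consumed_samples=0) | |||
for step in range(1, 5): | |||
    array.push(step * 16)   # a batch sampler with batch_size=16 pushes 16, 32, 48, 64 | |||
num_consumed_batches = 3    # batches the Trainer actually forwarded | |||
assert array[3] == 48       # the num_consumed_samples truly consumed at that point | |||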
@@ -13,7 +13,6 @@ __all__ = [ | |||
'torch_paddle_move_data_to_device', | |||
'torch_move_data_to_device', | |||
'get_fn_arg_names', | |||
'check_fn_not_empty_params', | |||
'auto_param_call', | |||
'check_user_specific_params', | |||
'dataclass_to_dict', | |||
@@ -36,7 +35,7 @@ from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_devi | |||
from .rich_progress import f_rich_progress | |||
from .torch_paddle_utils import torch_paddle_move_data_to_device | |||
from .torch_utils import torch_move_data_to_device | |||
from .utils import get_fn_arg_names, check_fn_not_empty_params, auto_param_call, check_user_specific_params, \ | |||
from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ | |||
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | |||
indice_collate_wrapper, deprecated, seq_len_to_mask, synchronize_safe_rm, synchronize_mkdir | |||
@@ -46,11 +46,14 @@ def get_paddle_device_id(device: Union[str, int]): | |||
device = device.lower() | |||
if device == "cpu": | |||
raise ValueError("Cannot get device id from `cpu`.") | |||
elif device == "gpu": | |||
return 0 | |||
match_res = re.match(r"gpu:\d+", device) | |||
if not match_res: | |||
raise ValueError( | |||
"The device must be a string which is like 'cpu', 'gpu', 'gpu:x'" | |||
"The device must be a string which is like 'cpu', 'gpu', 'gpu:x', " | |||
f"not '{device}'" | |||
) | |||
device_id = device.split(':', 1)[1] | |||
device_id = int(device_id) | |||
@@ -6,7 +6,7 @@ | |||
import sys | |||
from typing import Any, Union, Optional | |||
from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live | |||
from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live, Text, ProgressSample | |||
from rich.progress import ProgressColumn, TimeRemainingColumn, BarColumn, TimeElapsedColumn, TextColumn | |||
__all__ = [ | |||
@@ -146,24 +146,99 @@ class FRichProgress(Progress, metaclass=Singleton): | |||
if task_id in self._tasks: | |||
super().stop_task(task_id) | |||
super().remove_task(task_id) | |||
self.refresh() # refresh so the bar does not linger on screen | |||
def start(self) -> None: | |||
super().start() | |||
self.console.show_cursor(show=True) | |||
def update( | |||
self, | |||
task_id: TaskID, | |||
*, | |||
total: Optional[float] = None, | |||
completed: Optional[float] = None, | |||
advance: Optional[float] = None, | |||
description: Optional[str] = None, | |||
visible: Optional[bool] = None, | |||
refresh: bool = False, | |||
**fields: Any, | |||
) -> None: | |||
"""Update information associated with a task. | |||
Args: | |||
task_id (TaskID): Task id (returned by add_task). | |||
total (float, optional): Updates task.total if not None. | |||
completed (float, optional): Updates task.completed if not None. | |||
advance (float, optional): Add a value to task.completed if not None. | |||
description (str, optional): Change task description if not None. | |||
visible (bool, optional): Set visible flag if not None. | |||
refresh (bool): Force a refresh of progress information. Default is False. | |||
**fields (Any): Additional data fields required for rendering. | |||
""" | |||
with self._lock: | |||
task = self._tasks[task_id] | |||
completed_start = task.completed | |||
if total is not None and total != task.total: | |||
task.total = total | |||
task._reset() | |||
if advance is not None: | |||
task.completed += advance | |||
if completed is not None: | |||
task.completed = completed | |||
if description is not None: | |||
task.description = description | |||
if visible is not None: | |||
task.visible = visible | |||
task.fields.update(fields) | |||
update_completed = task.completed - completed_start | |||
current_time = self.get_time() | |||
old_sample_time = current_time - self.speed_estimate_period | |||
_progress = task._progress | |||
popleft = _progress.popleft | |||
# changed to keep at least one sample so that extremely long iterations do not distort the speed estimate | |||
while len(_progress)>1 and _progress[0].timestamp < old_sample_time: | |||
popleft() | |||
if update_completed > 0: | |||
_progress.append(ProgressSample(current_time, update_completed)) | |||
if task.completed >= task.total and task.finished_time is None: | |||
task.finished_time = task.elapsed | |||
if refresh: | |||
self.refresh() | |||
class SpeedColumn(ProgressColumn): | |||
""" | |||
Renders the speed of a task. | |||
""" | |||
def render(self, task: "Task"): | |||
speed = task.speed | |||
if speed is None: | |||
return Text('-- it./s', style='progress.data.speed') | |||
if speed > 0.1: | |||
return Text(str(round(speed, 2))+' it./s', style='progress.data.speed') | |||
else: | |||
return Text(str(round(1/speed, 2))+' s/it.', style='progress.data.speed') | |||
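In other words, a task running faster than 0.1 iterations per second renders as, e.g., '2.5 it./s', while slower tasks | |||
display the reciprocal, e.g., '12.5 s/it.'. | |||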
if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: | |||
f_rich_progress = FRichProgress().new_progess( | |||
"[progress.description]{task.description}", | |||
"[progress.percentage]{task.percentage:>3.0f}%", | |||
BarColumn(), | |||
SpeedColumn(), | |||
TimeElapsedColumn(), | |||
"/", | |||
TimeRemainingColumn(), | |||
TextColumn("{task.fields[post_desc]}", justify="right"), | |||
transient=True, | |||
disable=False, | |||
speed_estimate_period=1 | |||
speed_estimate_period=30 | |||
) | |||
else: | |||
f_rich_progress = DummyFRichProgress() | |||
@@ -1,9 +1,11 @@ | |||
from abc import ABC | |||
from typing import Any, Union, Optional | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_8 | |||
DEFAULT_TORCH_GROUP = None | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
if not _TORCH_GREATER_EQUAL_1_8: | |||
DEFAULT_TORCH_GROUP = torch.distributed.distributed_c10d.group.WORLD | |||
__all__ = [ | |||
'torch_move_data_to_device' | |||
@@ -1,3 +1,4 @@ | |||
import functools | |||
import inspect | |||
from inspect import Parameter | |||
import dataclasses | |||
@@ -24,10 +25,8 @@ from fastNLP.core.log import logger | |||
from fastNLP.envs import FASTNLP_GLOBAL_RANK | |||
__all__ = [ | |||
'get_fn_arg_names', | |||
'check_fn_not_empty_params', | |||
'auto_param_call', | |||
'check_user_specific_params', | |||
'dataclass_to_dict', | |||
@@ -44,48 +43,23 @@ __all__ = [ | |||
] | |||
def get_fn_arg_names(fn: Callable) -> List[str]: | |||
r""" | |||
Return the names of all parameters of a function; | |||
:param fn: the function to inspect; | |||
:return: a list whose elements are the string names of the function's parameters; | |||
""" | |||
return list(inspect.signature(fn).parameters) | |||
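For example: | |||
def f(a, b=1, *args, **kw): | |||
    pass | |||
assert get_fn_arg_names(f) == ['a', 'b', 'args', 'kw'] | |||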
def check_fn_not_empty_params(fn: Optional[Callable] = None, param_num: Optional[int] = None) -> bool: | |||
r""" | |||
Check that the given batch_step_fn is valid: (1) it is callable; (2) it has exactly the specified number of parameters without default values; | |||
users may also pass in a partial function, as long as the `trainer` and `batch` parameter slots are preserved; | |||
:param fn: the function meant to replace the 'step' function in the Loop; | |||
:param param_num: the expected number of parameters without default values; | |||
:return: bool indicating whether the given `batch_step_fn` is correct; | |||
""" | |||
if fn is None: | |||
return True | |||
if not callable(fn): | |||
return False | |||
else: | |||
params = inspect.signature(fn).parameters | |||
not_default_params = {} | |||
for _name, _param in params.items(): | |||
if _param.default == Parameter.empty: | |||
not_default_params[_name] = _param | |||
return len(not_default_params) == param_num | |||
def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None, | |||
mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any: | |||
r""" | |||
1. This function lets users drive automatic computation through string matching; | |||
the function looks up values matching the target function's parameter names in *args (all of which must therefore be dicts) and calls it with them; if the given data does not match fn's parameters, the mapping | |||
argument can translate the names: a (key, value) pair in mapping means "look up key in *args and pass that value to the parameter named value". | |||
1. This function lets users drive automatic invocation through string matching; | |||
2. note that mapping defaults to None; if you want to specify how the inputs correspond to the target function's parameters, pass mapping in as such a dict; | |||
if mapping is not None, the keys of the input dicts are always rewritten through it first, so be sure to verify that your mapping is correct; | |||
3. if a parameter of the target function has a default value, the default is used whenever no subsequent input supplies a value for it; otherwise the supplied value is used; | |||
@@ -113,6 +87,7 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20})) # res: 140 | |||
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200})) # res: 240 | |||
""" | |||
if signature_fn is not None: | |||
if not callable(signature_fn): | |||
raise ValueError(f"Parameter `signature_fn` should be `Callable`.") | |||
@@ -122,7 +97,8 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||
_kwargs = None | |||
for _name, _param in _need_params.items(): | |||
if _param.kind == Parameter.VAR_POSITIONAL: | |||
raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn.__name__}.") | |||
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | |||
raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn_msg}.") | |||
if _param.kind == Parameter.VAR_KEYWORD: | |||
_kwargs = (_name, _param) | |||
@@ -135,12 +111,17 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||
_default_params[_name] = _param.default | |||
if mapping is not None: | |||
assert isinstance(mapping, Dict), f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}." | |||
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | |||
assert isinstance(mapping, Dict), f"Exception happens when calling {fn_msg}. " \ | |||
f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}." | |||
_has_params = {} | |||
duplicate_names = [] | |||
for arg in args: | |||
assert isinstance(arg, Dict), "The input part of function `auto_param_call` can only be `Dict` type." | |||
if not isinstance(arg, Dict): | |||
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | |||
raise TypeError(f"Exception happens when calling {fn_msg}. " | |||
f"The input part of function `auto_param_call` must be `Dict` type, instead of {type(arg)}.") | |||
for _name, _value in arg.items(): | |||
if mapping is not None and _name in mapping: | |||
_name = mapping[_name] | |||
@@ -152,7 +133,8 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||
elif _name in _need_params and not (_has_params[_name] is _value): | |||
duplicate_names.append(_name) | |||
if duplicate_names: | |||
raise ValueError(f"The following key present in several inputs:{duplicate_names}") | |||
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | |||
raise ValueError(f"The following key present in several inputs:{duplicate_names} when calling {fn_msg}.") | |||
# pass along parameters that have default values and were not overridden by the inputs; | |||
for _name, _value in _default_params.items(): | |||
@@ -161,11 +143,89 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||
if len(_has_params)<len(_need_params): | |||
miss_params = list(set(_need_params.keys()) - set(_has_params.keys())) | |||
raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn.__name__} are not found in the input.") | |||
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | |||
_provided_keys = _get_keys(args) | |||
raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn_msg} " | |||
f"are not found in the input keys({_provided_keys}).") | |||
return fn(**_has_params) | |||
def _get_keys(args:List[Dict]) -> List[List[str]]: | |||
""" | |||
Return the keys of each dict | |||
:param args: | |||
:return: | |||
""" | |||
_provided_keys = [] | |||
for arg in args: | |||
_provided_keys.append(list(arg.keys())) | |||
return _provided_keys | |||
def _get_fun_msg(fn)->str: | |||
""" | |||
Collect basic information about a function to make error messages more helpful. | |||
ex: | |||
print(_get_fun_msg(_get_fun_msg)) | |||
# `_get_fun_msg(fn) -> str`(In file:/Users/hnyan/Desktop/projects/fastNLP/fastNLP/fastNLP/core/utils/utils.py) | |||
:param callable fn: | |||
:return: | |||
""" | |||
if isinstance(fn, functools.partial): | |||
return _get_fun_msg(fn.func) | |||
try: | |||
fn_name = fn.__qualname__ + str(inspect.signature(fn)) | |||
except: | |||
fn_name = str(fn) | |||
try: | |||
fp = '(In file:' + os.path.abspath(inspect.getfile(fn)) + ')' | |||
except: | |||
fp = '' | |||
msg = f'`{fn_name}`' + fp | |||
return msg | |||
def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None): | |||
""" | |||
Check whether a function can accept the expected_params parameters (a count check), after excluding self (for | |||
methods), parameters with default values, and so on. A mismatch raises an error. | |||
:param fn: the function to check; may be a method or a plain function. | |||
:param expected_params: the parameters the function is expected to accept. | |||
:param fn_name: the name of fn, for a clearer error message when the given fn is not callable. | |||
:return: | |||
""" | |||
if fn_name is not None: | |||
assert callable(fn), f"{fn_name} should be callable, instead of {type(fn)}." | |||
parameters = list(inspect.signature(fn).parameters.values()) | |||
if inspect.ismethod(fn): | |||
if len(parameters)>0 and parameters[0].name == 'self': | |||
parameters = parameters[1:] # drop self | |||
no_var_param = True # no *args/**kwargs-style parameters | |||
number_param_need_value = 0 | |||
for param in parameters: | |||
if param.kind is param.VAR_POSITIONAL: | |||
no_var_param = False | |||
elif param.kind is param.VAR_KEYWORD: | |||
no_var_param = False | |||
else: | |||
if param.default is param.empty: | |||
number_param_need_value += 1 | |||
if len(parameters)<len(expected_params) and no_var_param: | |||
raise RuntimeError(f"The function:{_get_fun_msg(fn)} accepts {len(parameters)} parameters, " | |||
f"but {len(expected_params)} parameters:{expected_params} will be provided.") | |||
if number_param_need_value>len(expected_params): | |||
raise RuntimeError(f"The function:{_get_fun_msg(fn)} expects {len(parameters)} parameters, but only" | |||
f" {len(expected_params)} parameters:{expected_params} will be provided.") | |||
def check_user_specific_params(user_params: Dict, fn: Callable): | |||
""" | |||
This function uses the user's input to assign values to the parameters of the given function; | |||
@@ -184,7 +244,7 @@ def check_user_specific_params(user_params: Dict, fn: Callable): | |||
return user_params | |||
def dataclass_to_dict(data: "dataclass") -> Dict: | |||
def dataclass_to_dict(data: "dataclasses.dataclass") -> Dict: | |||
if not is_dataclass(data): | |||
raise TypeError(f"Parameter `data` can only be `dataclass` type instead of {type(data)}.") | |||
_dict = dict() | |||
@@ -591,4 +651,24 @@ def synchronize_mkdir(path: Optional[Union[str, Path]]): | |||
wait_to_success(path.exists) | |||
def get_class_that_defined_method(method): | |||
""" | |||
Given a method, return the class object on which the method is defined | |||
:param method: | |||
:return: | |||
""" | |||
if isinstance(method, functools.partial): | |||
return get_class_that_defined_method(method.func) | |||
if inspect.ismethod(method) or (inspect.isbuiltin(method) and getattr(method, '__self__', None) is not None and getattr(method.__self__, '__class__', None)): | |||
for cls in inspect.getmro(method.__self__.__class__): | |||
if method.__name__ in cls.__dict__: | |||
return cls | |||
method = getattr(method, '__func__', method) # fallback to __qualname__ parsing | |||
if inspect.isfunction(method): | |||
cls = getattr(inspect.getmodule(method), | |||
method.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0], | |||
None) | |||
if isinstance(cls, type): | |||
return cls | |||
return getattr(method, '__objclass__', None) # handle special descriptor objects |
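A quick sketch of both resolution paths: | |||
class A: | |||
    def f(self): | |||
        pass | |||
assert get_class_that_defined_method(A().f) is A  # bound method, resolved via __self__ and the MRO | |||
assert get_class_that_defined_method(A.f) is A    # plain function, resolved by parsing __qualname__ | |||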
@@ -6,7 +6,8 @@ __all__ = [ | |||
'is_cur_env_distributed', | |||
'get_global_rank', | |||
'rank_zero_call', | |||
'all_rank_call' | |||
'all_rank_call', | |||
'get_gpu_count' | |||
] | |||
@@ -14,5 +15,5 @@ from .env import * | |||
from .set_env_on_import import set_env_on_import | |||
from .set_backend import dump_fastnlp_backend | |||
from .imports import * | |||
from .utils import _module_available | |||
from .utils import _module_available, get_gpu_count | |||
from .distributed import * |
@@ -45,6 +45,8 @@ FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK' | |||
# TODO: needs a comment | |||
FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" | |||
# default size for deques initialized inside fastNLP | |||
FASTNLP_DEQUE_SIZE = 'FASTNLP_DEQUE_SIZE' | |||
# TODO: needs a comment; variables used directly | |||
FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar" | |||
@@ -5,13 +5,13 @@ | |||
import os | |||
import json | |||
import sys | |||
import subprocess | |||
from collections import defaultdict | |||
from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED | |||
from fastNLP.envs.imports import SUPPORT_BACKENDS | |||
from fastNLP.envs.utils import _module_available | |||
from fastNLP.envs.utils import _module_available, get_gpu_count | |||
def _set_backend(): | |||
""" | |||
@@ -56,17 +56,18 @@ def _set_backend(): | |||
if 'PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ: | |||
# in a distributed child process, derive the devices this process actually owns from USER_CUDA_VISIBLE_DEVICES | |||
selected_gpus = os.environ['FLAGS_selected_gpus'].split(',') | |||
if user_visible_devices is not None and user_visible_devices != "": | |||
if user_visible_devices is not None: | |||
# the user launched distributed training through CUDA_VISIBLE_DEVICES | |||
# after set_backend, the user's setting is preserved in USER_CUDA_VISIBLE_DEVICES | |||
# from which we need to recover the device ids that are actually used | |||
user_visible_devices = user_visible_devices.split(",") | |||
selected_gpus = ",".join([user_visible_devices[int(i)] for i in selected_gpus]) | |||
else: | |||
# set USER_CUDA_VISIBLE_DEVICES to indicate that, from the user's point of view, all devices are visible | |||
os.environ[USER_CUDA_VISIBLE_DEVICES] = "" | |||
# TODO the [0] here may be problematic with multiple cards on a single node | |||
os.environ['CUDA_VISIBLE_DEVICES'] = selected_gpus[0] | |||
# USER_CUDA_VISIBLE_DEVICES was not found, so set it to all available devices | |||
os.environ[USER_CUDA_VISIBLE_DEVICES] = ",".join(map(str, list( | |||
range(get_gpu_count()) | |||
))) | |||
os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(selected_gpus) | |||
os.environ['FLAGS_selected_gpus'] = ",".join([str(g) for g in range(len(selected_gpus))]) | |||
os.environ['FLAGS_selected_accelerators'] = ",".join([str(g) for g in range(len(selected_gpus))]) | |||
elif 'CUDA_VISIBLE_DEVICES' in os.environ: | |||
@@ -78,7 +79,9 @@ def _set_backend(): | |||
else: | |||
# if nothing is set, restrict to a single card to keep multi-process runs from occupying other cards | |||
os.environ['CUDA_VISIBLE_DEVICES'] = '0' | |||
os.environ[USER_CUDA_VISIBLE_DEVICES] = '' | |||
os.environ[USER_CUDA_VISIBLE_DEVICES] = ",".join(map(str, list( | |||
range(get_gpu_count()) | |||
))) | |||
elif backend == 'jittor': | |||
assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | |||
@@ -36,8 +36,7 @@ def set_env_on_import_torch(): | |||
# TODO paddle may need set this | |||
def set_env_on_import_paddle(): | |||
# TODO: FASTNLP_GLOBAL_RANK and FASTNLP_LAUNCH_PROCESS need to be set here | |||
if "PADDLE_TRANERS_NUM" in os.environ and "PADDLE_TRAINER_ID" in os.environ \ | |||
if "PADDLE_TRAINERS_NUM" in os.environ and "PADDLE_TRAINER_ID" in os.environ \ | |||
and "PADDLE_RANK_IN_NODE" in os.environ: | |||
# environment variables of a distributed environment were detected | |||
os.environ[FASTNLP_GLOBAL_RANK] = os.environ["PADDLE_TRAINER_ID"] | |||
@@ -3,6 +3,7 @@ from typing import Callable | |||
import importlib | |||
from pkg_resources import DistributionNotFound | |||
from packaging.version import Version | |||
import subprocess | |||
import pkg_resources | |||
@@ -46,3 +47,15 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version: | |||
if use_base_version: | |||
pkg_version = Version(pkg_version.base_version) | |||
return op(pkg_version, Version(version)) | |||
def get_gpu_count(): | |||
""" | |||
Query the number of GPUs via the command line | |||
:return: the number of GPUs, or -1 if no GPU device is present | |||
""" | |||
try: | |||
lines = subprocess.check_output(['nvidia-smi', '--query-gpu=memory.used', '--format=csv']) | |||
# after splitting, the header row and the trailing newline still have to be discarded | |||
return len(lines.split(b"\n")) - 2 | |||
except: | |||
return -1 |
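The arithmetic assumes nvidia-smi prints one header row, one row per GPU, and a trailing newline; a sketch: | |||
lines = b"memory.used [MiB]\n411 MiB\n0 MiB\n"  # sample output for 2 GPUs (values illustrative) | |||
assert len(lines.split(b"\n")) - 2 == 2         # the header row and trailing empty chunk are discarded | |||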
@@ -251,10 +251,10 @@ class DataBundle: | |||
def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True, | |||
ignore_miss_dataset=True, progress_desc: str = '', show_progress_bar: bool = True): | |||
r""" | |||
Apply the :meth:`~fastNLP.DataSet.apply_field_more` method to every dataset in the :class:`~fastNLP.io.DataBundle` | |||
Apply the :method:`~fastNLP.DataSet.apply_field_more` method to every dataset in the :class:`~fastNLP.io.DataBundle` | |||
.. note:: | |||
For the difference between ``apply_field_more`` and ``apply_field``, see the discussion of ``apply_more`` versus ``apply`` in :meth:`fastNLP.DataSet.apply_more`. | |||
For the difference between ``apply_field_more`` and ``apply_field``, see the discussion of ``apply_more`` versus ``apply`` in :method:`fastNLP.DataSet.apply_more`. | |||
:param callable func: takes an ``Instance`` of the ``DataSet`` and returns a dict whose keys are field names and whose values are the corresponding results | |||
@@ -285,7 +285,7 @@ class DataBundle: | |||
def apply(self, func: Callable, new_field_name: str, num_proc: int = 0, | |||
progress_desc: str = '', show_progress_bar: bool = True, _apply_field: str = None): | |||
r""" | |||
Apply the :meth:`~fastNLP.DataSet.apply` method to every dataset in the :class:`~fastNLP.io.DataBundle` | |||
Apply the :method:`~fastNLP.DataSet.apply` method to every dataset in the :class:`~fastNLP.io.DataBundle` | |||
Uses the apply method on every dataset in the DataBundle | |||
@@ -309,10 +309,10 @@ class DataBundle: | |||
def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0, | |||
progress_desc: str = '', show_progress_bar: bool = True): | |||
r""" | |||
Apply the :meth:`~fastNLP.DataSet.apply_more` method to every dataset in the :class:`~fastNLP.io.DataBundle` | |||
Apply the :method:`~fastNLP.DataSet.apply_more` method to every dataset in the :class:`~fastNLP.io.DataBundle` | |||
.. note:: | |||
For the difference between ``apply_more`` and ``apply``, see the discussion of ``apply_more`` versus ``apply`` in :meth:`fastNLP.DataSet.apply_more`. | |||
For the difference between ``apply_more`` and ``apply``, see the discussion of ``apply_more`` versus ``apply`` in :method:`fastNLP.DataSet.apply_more`. | |||
:param callable func: takes an ``Instance`` of the ``DataSet`` and returns a dict whose keys are field names and whose values are the corresponding results | |||
@@ -87,7 +87,7 @@ class CLSBasePipe(Pipe): | |||
def process_from_file(self, paths) -> DataBundle: | |||
r""" | |||
Given file path(s), build a processed DataBundle object. For the path formats supported by paths, see ::meth:`fastNLP.io.Loader.load()` | |||
Given file path(s), build a processed DataBundle object. For the path formats supported by paths, see ::method:`fastNLP.io.Loader.load()` | |||
:param paths: | |||
:return: DataBundle | |||
@@ -164,7 +164,7 @@ class GraphBuilderBase: | |||
def build_graph_from_file(self, path: str): | |||
r""" | |||
Given a file path, build a processed scipy_sparse_matrix object. For the path formats supported by paths, see ::meth:`fastNLP.io.Loader.load()` | |||
Given a file path, build a processed scipy_sparse_matrix object. For the path formats supported by paths, see ::method:`fastNLP.io.Loader.load()` | |||
:param path: | |||
:return: scipy_sparse_matrix | |||
@@ -33,7 +33,7 @@ class Pipe: | |||
def process_from_file(self, paths: str) -> DataBundle: | |||
r""" | |||
Given file path(s), build a processed DataBundle object. For the path formats supported by paths, see ::meth:`fastNLP.io.Loader.load()` | |||
Given file path(s), build a processed DataBundle object. For the path formats supported by paths, see ::method:`fastNLP.io.Loader.load()` | |||
:param str paths: | |||
:return: DataBundle | |||
@@ -1,7 +1,7 @@ | |||
import pytest | |||
from functools import reduce | |||
from fastNLP.core.callbacks.callback_events import Filter | |||
from fastNLP.core.callbacks.callback_events import Events, Filter | |||
class TestFilter: | |||
@@ -10,7 +10,7 @@ import re | |||
from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback | |||
from fastNLP.core.controllers.trainer import Trainer | |||
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | |||
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | |||
from tests.helpers.utils import magic_argv_env_context | |||
from fastNLP.core import synchronize_safe_rm | |||
@@ -238,7 +238,7 @@ def test_model_checkpoint_callback_2( | |||
from fastNLP.core.callbacks.callback_events import Events | |||
@Trainer.on(Events.ON_TRAIN_EPOCH_END) | |||
@Trainer.on(Events.on_train_epoch_end) | |||
def raise_exception(trainer): | |||
if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: | |||
raise NotImplementedError | |||
@@ -0,0 +1,93 @@ | |||
""" | |||
This file tests the case where the user launches training with python -m paddle.distributed.launch | |||
TODO: check whether there is a way to run it under pytest | |||
python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py | |||
""" | |||
import os | |||
os.environ["FASTNLP_BACKEND"] = "paddle" | |||
import sys | |||
sys.path.append("../../../") | |||
from dataclasses import dataclass | |||
from fastNLP.core.controllers.trainer import Trainer | |||
from fastNLP.core.metrics.accuracy import Accuracy | |||
from fastNLP.core.callbacks.progress_callback import RichCallback | |||
from fastNLP.core.callbacks import Callback | |||
import paddle | |||
from paddle.optimizer import Adam | |||
from paddle.io import DataLoader | |||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | |||
from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback | |||
@dataclass | |||
class MNISTTrainFleetConfig: | |||
num_labels: int = 10 | |||
feature_dimension: int = 10 | |||
batch_size: int = 32 | |||
shuffle: bool = True | |||
validate_every = -1 | |||
def test_trainer_fleet( | |||
driver, | |||
device, | |||
callbacks, | |||
n_epochs, | |||
): | |||
model = PaddleNormalModel_Classification_1( | |||
num_labels=MNISTTrainFleetConfig.num_labels, | |||
feature_dimension=MNISTTrainFleetConfig.feature_dimension | |||
) | |||
optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) | |||
train_dataloader = DataLoader( | |||
dataset=PaddleRandomMaxDataset(6400, MNISTTrainFleetConfig.feature_dimension), | |||
batch_size=MNISTTrainFleetConfig.batch_size, | |||
shuffle=True | |||
) | |||
val_dataloader = DataLoader( | |||
dataset=PaddleRandomMaxDataset(1280, MNISTTrainFleetConfig.feature_dimension), | |||
batch_size=MNISTTrainFleetConfig.batch_size, | |||
shuffle=True | |||
) | |||
validate_dataloaders = val_dataloader | |||
validate_every = MNISTTrainFleetConfig.validate_every | |||
metrics = {"acc": Accuracy()} | |||
trainer = Trainer( | |||
model=model, | |||
driver=driver, | |||
device=device, | |||
optimizers=optimizers, | |||
train_dataloader=train_dataloader, | |||
validate_dataloaders=validate_dataloaders, | |||
validate_every=validate_every, | |||
input_mapping=None, | |||
output_mapping=None, | |||
metrics=metrics, | |||
n_epochs=n_epochs, | |||
callbacks=callbacks, | |||
output_from_new_proc="logs", | |||
) | |||
trainer.run() | |||
if __name__ == "__main__": | |||
driver = "fleet" | |||
device = [0,2,3] | |||
# driver = "paddle" | |||
# device = 2 | |||
callbacks = [ | |||
# RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), | |||
RichCallback(5), | |||
] | |||
test_trainer_fleet( | |||
driver=driver, | |||
device=device, | |||
callbacks=callbacks, | |||
n_epochs=5, | |||
) |
@@ -0,0 +1,98 @@ | |||
""" | |||
This file tests the case where the user launches training with python -m paddle.distributed.launch | |||
and initializes fleet by themselves | |||
python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py | |||
""" | |||
import os | |||
os.environ["FASTNLP_BACKEND"] = "paddle" | |||
import sys | |||
sys.path.append("../../../") | |||
from dataclasses import dataclass | |||
from fastNLP.core.controllers.trainer import Trainer | |||
from fastNLP.core.metrics.accuracy import Accuracy | |||
from fastNLP.core.callbacks.progress_callback import RichCallback | |||
from fastNLP.core.callbacks import Callback | |||
import paddle | |||
from paddle.optimizer import Adam | |||
from paddle.io import DataLoader | |||
import paddle.distributed.fleet as fleet | |||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_2 | |||
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | |||
from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback | |||
@dataclass | |||
class MNISTTrainFleetConfig: | |||
num_labels: int = 10 | |||
feature_dimension: int = 10 | |||
batch_size: int = 32 | |||
shuffle: bool = True | |||
validate_every = -1 | |||
def test_trainer_fleet( | |||
driver, | |||
device, | |||
callbacks, | |||
n_epochs, | |||
): | |||
fleet.init(is_collective=True) | |||
model = PaddleNormalModel_Classification_2( | |||
num_labels=MNISTTrainFleetConfig.num_labels, | |||
feature_dimension=MNISTTrainFleetConfig.feature_dimension, | |||
) | |||
optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) | |||
model = fleet.distributed_model(model) | |||
optimizers = fleet.distributed_optimizer(optimizers) | |||
train_dataloader = DataLoader( | |||
dataset=PaddleRandomMaxDataset(6400, MNISTTrainFleetConfig.feature_dimension), | |||
batch_size=MNISTTrainFleetConfig.batch_size, | |||
shuffle=True | |||
) | |||
val_dataloader = DataLoader( | |||
dataset=PaddleRandomMaxDataset(1280, MNISTTrainFleetConfig.feature_dimension), | |||
batch_size=MNISTTrainFleetConfig.batch_size, | |||
shuffle=True | |||
) | |||
train_dataloader = train_dataloader | |||
validate_dataloaders = val_dataloader | |||
validate_every = MNISTTrainFleetConfig.validate_every | |||
metrics = {"acc": Accuracy()} | |||
trainer = Trainer( | |||
model=model, | |||
driver=driver, | |||
device=device, | |||
optimizers=optimizers, | |||
train_dataloader=train_dataloader, | |||
validate_dataloaders=validate_dataloaders, | |||
validate_every=validate_every, | |||
input_mapping=None, | |||
output_mapping=None, | |||
metrics=metrics, | |||
n_epochs=n_epochs, | |||
callbacks=callbacks, | |||
output_from_new_proc="logs", | |||
data_device=f"gpu:{os.environ['CUDA_VISIBLE_DEVICES']}" | |||
) | |||
trainer.run() | |||
if __name__ == "__main__": | |||
driver = "fleet" | |||
device = [0,2,3] | |||
callbacks = [ | |||
# RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), | |||
RichCallback(5), | |||
] | |||
test_trainer_fleet( | |||
driver=driver, | |||
device=device, | |||
callbacks=callbacks, | |||
n_epochs=30, | |||
) |
@@ -0,0 +1,25 @@ | |||
import pytest | |||
from fastNLP.core.controllers.trainer import Trainer | |||
from fastNLP.core.callbacks import Events | |||
from tests.helpers.utils import magic_argv_env_context | |||
@magic_argv_env_context | |||
def test_trainer_torch_without_evaluator(): | |||
@Trainer.on(Events.ON_TRAIN_EPOCH_BEGIN(every=10)) | |||
def fn1(trainer): | |||
pass | |||
@Trainer.on(Events.ON_TRAIN_BATCH_BEGIN(every=10)) | |||
def fn2(trainer, batch, indices): | |||
pass | |||
with pytest.raises(AssertionError): | |||
@Trainer.on(Events.ON_TRAIN_BATCH_BEGIN(every=10)) | |||
def fn3(trainer, batch): | |||
pass | |||
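fn3 above fails because a handler registered via Trainer.on must accept the same parameters as the corresponding Callback hook (here trainer, batch, indices for on_train_batch_begin). A minimal sketch of that kind of signature check, assuming simple inspect-based parameter counting rather than fastNLP's actual _check_valid_parameters_number:

import inspect

def check_handler_signature(fn, expected=("trainer", "batch", "indices")):
    # The handler must declare exactly as many parameters as the hook provides.
    params = tuple(inspect.signature(fn).parameters)
    assert len(params) == len(expected), \
        f"{fn.__name__} should take {expected}, got {params}"

def ok(trainer, batch, indices): pass
check_handler_signature(ok)      # passes, like fn2 above
def bad(trainer, batch): pass
# check_handler_signature(bad)   # raises AssertionError, like fn3 above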
@@ -98,14 +98,16 @@ def model_and_optimizers(request): | |||
# Test the ordinary case;
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) #, ("torch", 1), ("torch", [0, 1]) | |||
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) | |||
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]]) | |||
@pytest.mark.parametrize("validate_every", [-3]) | |||
@magic_argv_env_context | |||
def test_trainer_torch_with_evaluator( | |||
model_and_optimizers: TrainerParameters, | |||
driver, | |||
device, | |||
callbacks, | |||
validate_every, | |||
n_epochs=10, | |||
): | |||
trainer = Trainer( | |||
@@ -118,11 +120,11 @@ def test_trainer_torch_with_evaluator( | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
validate_every=validate_every, | |||
n_epochs=n_epochs, | |||
callbacks=callbacks, | |||
output_from_new_proc="all" | |||
) | |||
trainer.run() | |||
@@ -143,7 +145,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||
accumulation_steps, | |||
n_epochs=6, | |||
): | |||
callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)] | |||
callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.1, larger_better=True)] | |||
trainer = Trainer( | |||
model=model_and_optimizers.model, | |||
driver=driver, | |||
@@ -169,4 +171,42 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||
dist.destroy_process_group() | |||
@pytest.mark.parametrize("driver,device", [("torch", 1)]) # ("torch", [0, 1]),("torch", 1) | |||
@magic_argv_env_context | |||
def test_trainer_validate_every( | |||
model_and_optimizers: TrainerParameters, | |||
driver, | |||
device, | |||
n_epochs=6, | |||
): | |||
def validate_every(trainer): | |||
if trainer.global_forward_batches % 10 == 0: | |||
print(trainer) | |||
print("\nfastNLP test validate every.\n") | |||
print(trainer.global_forward_batches) | |||
return True | |||
trainer = Trainer( | |||
model=model_and_optimizers.model, | |||
driver=driver, | |||
device=device, | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
n_epochs=n_epochs, | |||
output_from_new_proc="all", | |||
validate_every=validate_every | |||
) | |||
trainer.run() | |||
if dist.is_initialized(): | |||
dist.destroy_process_group() | |||
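The test above passes a callable as validate_every; it is invoked with the trainer and decides per step whether to evaluate. A sketch of the dispatch being exercised — assumed semantics, not the actual Trainer internals (the meaning of negative integers, used elsewhere in these tests, is omitted here):

def should_validate(trainer, validate_every):
    if callable(validate_every):
        # User-supplied predicate, as in test_trainer_validate_every.
        return bool(validate_every(trainer))
    if isinstance(validate_every, int) and validate_every > 0:
        # Assumed: positive integers mean 'evaluate every N batches'.
        return trainer.global_forward_batches % validate_every == 0
    return False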
@@ -10,7 +10,7 @@ from typing import Any | |||
from pathlib import Path | |||
from fastNLP.core.controllers.trainer import Trainer | |||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1, TorchNormalModel_Classification_3 | |||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | |||
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | |||
@@ -70,7 +70,7 @@ def model_and_optimizers(request): | |||
trainer_params.output_mapping = None | |||
# elif request.param == 1: | |||
# model = | |||
return trainer_params | |||
@@ -254,7 +254,7 @@ def test_trainer_on_exception( | |||
): | |||
from fastNLP.core.callbacks.callback_events import Events | |||
@Trainer.on(Events.ON_TRAIN_EPOCH_END) | |||
@Trainer.on(Events.on_train_epoch_end) | |||
def raise_exception(trainer): | |||
if trainer.driver.get_local_rank() == cur_rank: | |||
raise NotImplementedError | |||
@@ -307,10 +307,47 @@ def test_torch_distributed_launch_2(version): | |||
subprocess.check_call(command) | |||
@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch_ddp", [0, 1])]) | |||
@magic_argv_env_context | |||
def test_torch_wo_auto_param_call( | |||
driver, | |||
device, | |||
n_epochs=10, | |||
): | |||
model = TorchNormalModel_Classification_3( | |||
num_labels=NormalClassificationTrainTorchConfig.num_labels, | |||
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension | |||
) | |||
optimizers = SGD(model.parameters(), lr=0.001) | |||
dataset = TorchNormalDataset_Classification( | |||
num_labels=NormalClassificationTrainTorchConfig.num_labels, | |||
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension, | |||
each_label_data=NormalClassificationTrainTorchConfig.each_label_data, | |||
seed=NormalClassificationTrainTorchConfig.seed | |||
) | |||
train_dataloader = DataLoader( | |||
dataset=dataset, | |||
batch_size=NormalClassificationTrainTorchConfig.batch_size, | |||
shuffle=True | |||
) | |||
trainer = Trainer( | |||
model=model, | |||
driver=driver, | |||
device=device, | |||
optimizers=optimizers, | |||
train_dataloader=train_dataloader, | |||
n_epochs=n_epochs, | |||
model_wo_auto_param_call=True, | |||
output_from_new_proc="all" | |||
) | |||
trainer.run() | |||
if dist.is_initialized(): | |||
dist.destroy_process_group() | |||
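test_torch_wo_auto_param_call sets model_wo_auto_param_call=True, meaning the batch is handed to the model as-is instead of being unpacked to match forward's parameter names. A sketch of the matching behavior that flag switches off (assumed semantics, not fastNLP's actual auto_param_call):

import inspect

def auto_param_call(forward, batch):
    # Keep only the batch entries whose keys match forward's parameter names.
    names = inspect.signature(forward).parameters
    return forward(**{k: v for k, v in batch.items() if k in names})

def forward(x, y):
    return x + y

assert auto_param_call(forward, {"x": 1, "y": 2, "z": 3}) == 3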
@@ -1,83 +1,103 @@ | |||
import os | |||
import pytest | |||
from fastNLP.envs.set_backend import set_env | |||
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle | |||
set_env_on_import_paddle() | |||
set_env("paddle") | |||
import paddle | |||
os.environ["FASTNLP_BACKEND"] = "paddle" | |||
from fastNLP.core.drivers import PaddleSingleDriver, PaddleFleetDriver | |||
from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver | |||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | |||
from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver | |||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | |||
from fastNLP.envs import get_gpu_count | |||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||
from tests.helpers.utils import magic_argv_env_context | |||
import paddle | |||
def test_incorrect_driver(): | |||
model = PaddleNormalModel_Classification_1(2, 100) | |||
with pytest.raises(ValueError): | |||
driver = initialize_paddle_driver("torch") | |||
driver = initialize_paddle_driver("torch", 0, model) | |||
@pytest.mark.parametrize( | |||
"device", | |||
["cpu", "gpu:0", [1, 2, 3], 0, "gpu:1"] | |||
["cpu", "gpu:0", 0, [1]] | |||
) | |||
def test_get_single_device(device): | |||
@pytest.mark.parametrize( | |||
"driver", | |||
["paddle"] | |||
) | |||
def test_get_single_device(driver, device): | |||
""" | |||
Test initializing PaddleSingleDriver in the normal case
""" | |||
model = PaddleNormalModel_Classification(2, 100) | |||
driver = initialize_paddle_driver("paddle", device, model) | |||
model = PaddleNormalModel_Classification_1(2, 100) | |||
driver = initialize_paddle_driver(driver, device, model) | |||
assert isinstance(driver, PaddleSingleDriver) | |||
@pytest.mark.parametrize( | |||
"device", | |||
["cpu", "gpu:0", [1, 2, 3], 0, "gpu:1"] | |||
[0, 1] | |||
) | |||
def test_get_single_device_with_visiblde_devices(device): | |||
@pytest.mark.parametrize( | |||
"driver", | |||
["fleet"] | |||
) | |||
@magic_argv_env_context | |||
def test_get_fleet_2(driver, device): | |||
""" | |||
Test initializing PaddleSingleDriver when launched with CUDA_VISIBLE_DEVICES
Test fleet multi-GPU initialization
""" | |||
# TODO | |||
model = PaddleNormalModel_Classification(2, 100) | |||
driver = initialize_paddle_driver("paddle", device, model) | |||
model = PaddleNormalModel_Classification_1(64, 10) | |||
driver = initialize_paddle_driver(driver, device, model) | |||
assert isinstance(driver, PaddleSingleDriver) | |||
assert isinstance(driver, PaddleFleetDriver) | |||
@pytest.mark.parametrize( | |||
"device", | |||
[[1, 2, 3]] | |||
[[0, 2, 3], -1] | |||
) | |||
@pytest.mark.parametrize( | |||
"driver", | |||
["paddle", "fleet"] | |||
) | |||
def test_get_fleet(device): | |||
@magic_argv_env_context | |||
def test_get_fleet(driver, device): | |||
""" | |||
Test fleet multi-GPU initialization
""" | |||
model = PaddleNormalModel_Classification(2, 100) | |||
driver = initialize_paddle_driver("paddle", device, model) | |||
model = PaddleNormalModel_Classification_1(64, 10) | |||
driver = initialize_paddle_driver(driver, device, model) | |||
assert isinstance(driver, PaddleFleetDriver) | |||
@pytest.mark.parametrize( | |||
"device", | |||
[[1,2,3]] | |||
("driver", "device"), | |||
[("fleet", "cpu")] | |||
) | |||
def test_get_fleet(device): | |||
@magic_argv_env_context | |||
def test_get_fleet_cpu(driver, device): | |||
""" | |||
Test fleet multi-GPU initialization when started via launch
Test attempting to initialize distributed training on cpu
""" | |||
# TODO | |||
model = PaddleNormalModel_Classification(2, 100) | |||
driver = initialize_paddle_driver("paddle", device, model) | |||
assert isinstance(driver, PaddleFleetDriver) | |||
model = PaddleNormalModel_Classification_1(64, 10) | |||
with pytest.raises(ValueError): | |||
driver = initialize_paddle_driver(driver, device, model) | |||
def test_device_out_of_range(device): | |||
@pytest.mark.parametrize( | |||
"device", | |||
[-2, [0, get_gpu_count() + 1, 3], [-2], get_gpu_count() + 1] | |||
) | |||
@pytest.mark.parametrize( | |||
"driver", | |||
["paddle", "fleet"] | |||
) | |||
@magic_argv_env_context | |||
def test_device_out_of_range(driver, device): | |||
""" | |||
Test the case where the given device is out of range
""" | |||
pass | |||
model = PaddleNormalModel_Classification_1(2, 100) | |||
with pytest.raises(ValueError): | |||
driver = initialize_paddle_driver(driver, device, model) |
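Taken together, these tests pin down a dispatch rule for initialize_paddle_driver. The sketch below restates that rule as the assertions imply it (an assumption; the real function also validates device ranges, which is why test_device_out_of_range expects ValueError):

def pick_paddle_driver(driver, device):
    if driver not in ("paddle", "fleet"):
        raise ValueError("driver should be 'paddle' or 'fleet'")
    if driver == "fleet" and device == "cpu":
        raise ValueError("distributed training cannot run on cpu")
    if driver == "fleet" or (isinstance(device, list) and len(device) > 1) or device == -1:
        return "PaddleFleetDriver"
    return "PaddleSingleDriver"

assert pick_paddle_driver("paddle", [1]) == "PaddleSingleDriver"
assert pick_paddle_driver("paddle", [0, 2, 3]) == "PaddleFleetDriver"
assert pick_paddle_driver("fleet", 0) == "PaddleFleetDriver"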
@@ -1,262 +0,0 @@ | |||
import unittest | |||
import torch | |||
from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver | |||
import paddle | |||
from paddle.io import Dataset, DataLoader | |||
class Net(paddle.nn.Layer): | |||
def __init__(self): | |||
super(Net, self).__init__() | |||
self.fc1 = paddle.nn.Linear(784, 64) | |||
self.fc2 = paddle.nn.Linear(64, 32) | |||
self.fc3 = paddle.nn.Linear(32, 10) | |||
self.fc4 = paddle.nn.Linear(10, 10) | |||
def forward(self, x): | |||
x = self.fc1(x) | |||
x = self.fc2(x) | |||
x = self.fc3(x) | |||
x = self.fc4(x) | |||
return x | |||
class PaddleDataset(Dataset): | |||
def __init__(self): | |||
super(PaddleDataset, self).__init__() | |||
self.items = [paddle.rand((3, 4)) for i in range(320)] | |||
def __len__(self): | |||
return len(self.items) | |||
def __getitem__(self, idx): | |||
return self.items[idx] | |||
class TorchNet(torch.nn.Module): | |||
def __init__(self): | |||
super(TorchNet, self).__init__() | |||
self.torch_fc1 = torch.nn.Linear(10, 10) | |||
self.torch_softmax = torch.nn.Softmax(0) | |||
self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3) | |||
self.torch_tensor = torch.ones(3, 3) | |||
self.torch_param = torch.nn.Parameter(torch.ones(4, 4)) | |||
class TorchDataset(torch.utils.data.Dataset): | |||
def __init__(self): | |||
super(TorchDataset, self).__init__() | |||
self.items = [torch.ones(3, 4) for i in range(320)] | |||
def __len__(self): | |||
return len(self.items) | |||
def __getitem__(self, idx): | |||
return self.items[idx] | |||
class PaddleDriverTestCase(unittest.TestCase): | |||
""" | |||
Test class for PaddleDriver; because of this class's special role only some functions are tested here, the rest are covered by the PaddleSingleDriver and PaddleFleetDriver tests
""" | |||
def setUp(self): | |||
model = Net() | |||
self.driver = PaddleDriver(model) | |||
def test_check_single_optimizer_legacy(self): | |||
""" | |||
Test the behavior when a single optimizer is passed in
""" | |||
optimizer = paddle.optimizer.Adam( | |||
parameters=self.driver.model.parameters(), | |||
learning_rate=0.01 | |||
) | |||
self.driver.set_optimizers(optimizer) | |||
optimizer = torch.optim.Adam(TorchNet().parameters(), 0.01) | |||
# Passing a torch optimizer should raise ValueError
with self.assertRaises(ValueError) as cm: | |||
self.driver.set_optimizers(optimizer) | |||
def test_check_optimizers_legacy(self): | |||
""" | |||
Test the behavior when a list of optimizers is passed in
""" | |||
optimizers = [ | |||
paddle.optimizer.Adam( | |||
parameters=self.driver.model.parameters(), | |||
learning_rate=0.01 | |||
) for i in range(10) | |||
] | |||
self.driver.set_optimizers(optimizers) | |||
optimizers += [ | |||
torch.optim.Adam(TorchNet().parameters(), 0.01) | |||
] | |||
with self.assertRaises(ValueError) as cm: | |||
self.driver.set_optimizers(optimizers) | |||
def test_check_dataloader_legacy_in_train(self): | |||
""" | |||
Test _check_dataloader_legality when the is_train parameter is True
""" | |||
dataloader = paddle.io.DataLoader(PaddleDataset()) | |||
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||
# Create a torch dataloader
dataloader = torch.utils.data.DataLoader( | |||
TorchDataset(), | |||
batch_size=32, shuffle=True | |||
) | |||
with self.assertRaises(ValueError) as cm: | |||
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||
def test_check_dataloader_legacy_in_test(self): | |||
""" | |||
Test _check_dataloader_legality when the is_train parameter is False
"""
# In this case the argument should be a dict
dataloader = {"train": paddle.io.DataLoader(PaddleDataset()), "test":paddle.io.DataLoader(PaddleDataset())} | |||
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||
# Not a dict: should raise an error
dataloader = paddle.io.DataLoader(PaddleDataset()) | |||
with self.assertRaises(ValueError) as cm: | |||
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||
# Create a torch dataloader
train_loader = torch.utils.data.DataLoader( | |||
TorchDataset(), | |||
batch_size=32, shuffle=True | |||
) | |||
test_loader = torch.utils.data.DataLoader( | |||
TorchDataset(), | |||
batch_size=32, shuffle=True | |||
) | |||
dataloader = {"train": train_loader, "test": test_loader} | |||
with self.assertRaises(ValueError) as cm: | |||
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||
def test_tensor_to_numeric(self): | |||
""" | |||
Test the tensor_to_numeric function
"""
# A single tensor
tensor = paddle.to_tensor(3) | |||
res = PaddleDriver.tensor_to_numeric(tensor) | |||
self.assertEqual(res, 3) | |||
tensor = paddle.rand((3, 4)) | |||
res = PaddleDriver.tensor_to_numeric(tensor) | |||
self.assertListEqual(res, tensor.tolist()) | |||
# A list of tensors
tensor_list = [paddle.rand((6, 4, 2)) for i in range(10)] | |||
res = PaddleDriver.tensor_to_numeric(tensor_list) | |||
self.assertTrue(res, list) | |||
tensor_list = [t.tolist() for t in tensor_list] | |||
self.assertListEqual(res, tensor_list) | |||
# A tuple of tensors
tensor_tuple = tuple([paddle.rand((6, 4, 2)) for i in range(10)]) | |||
res = PaddleDriver.tensor_to_numeric(tensor_tuple) | |||
self.assertTrue(res, tuple) | |||
tensor_tuple = tuple([t.tolist() for t in tensor_tuple]) | |||
self.assertTupleEqual(res, tensor_tuple) | |||
# A dict of tensors
tensor_dict = { | |||
"tensor": paddle.rand((3, 4)), | |||
"list": [paddle.rand((6, 4, 2)) for i in range(10)], | |||
"dict":{ | |||
"list": [paddle.rand((6, 4, 2)) for i in range(10)], | |||
"tensor": paddle.rand((3, 4)) | |||
}, | |||
"int": 2, | |||
"string": "test string" | |||
} | |||
res = PaddleDriver.tensor_to_numeric(tensor_dict) | |||
self.assertIsInstance(res, dict) | |||
self.assertListEqual(res["tensor"], tensor_dict["tensor"].tolist()) | |||
self.assertIsInstance(res["list"], list) | |||
for r, d in zip(res["list"], tensor_dict["list"]): | |||
self.assertListEqual(r, d.tolist()) | |||
self.assertIsInstance(res["int"], int) | |||
self.assertIsInstance(res["string"], str) | |||
self.assertIsInstance(res["dict"], dict) | |||
self.assertIsInstance(res["dict"]["list"], list) | |||
for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]): | |||
self.assertListEqual(r, d.tolist()) | |||
self.assertListEqual(res["dict"]["tensor"], tensor_dict["dict"]["tensor"].tolist()) | |||
def test_set_model_mode(self): | |||
""" | |||
Test the set_model_mode function
""" | |||
self.driver.set_model_mode("train") | |||
self.assertTrue(self.driver.model.training) | |||
self.driver.set_model_mode("eval") | |||
self.assertFalse(self.driver.model.training) | |||
# Should raise an error
with self.assertRaises(AssertionError) as cm: | |||
self.driver.set_model_mode("test") | |||
def test_move_model_to_device_cpu(self): | |||
""" | |||
Test the move_model_to_device function
""" | |||
PaddleDriver.move_model_to_device(self.driver.model, "cpu") | |||
self.assertTrue(self.driver.model.fc1.weight.place.is_cpu_place()) | |||
def test_move_model_to_device_gpu(self): | |||
""" | |||
Test the move_model_to_device function
""" | |||
PaddleDriver.move_model_to_device(self.driver.model, "gpu:0") | |||
self.assertTrue(self.driver.model.fc1.weight.place.is_gpu_place()) | |||
self.assertEqual(self.driver.model.fc1.weight.place.gpu_device_id(), 0) | |||
def test_worker_init_function(self): | |||
""" | |||
Test worker_init_function
"""
# For now just make sure it runs without errors
# TODO: verify correctness
PaddleDriver.worker_init_function(0) | |||
def test_set_deterministic_dataloader(self): | |||
""" | |||
Test set_deterministic_dataloader
"""
# For now just make sure it runs without errors
# TODO: verify correctness
dataloader = DataLoader(PaddleDataset()) | |||
self.driver.set_deterministic_dataloader(dataloader) | |||
def test_set_sampler_epoch(self): | |||
""" | |||
Test set_sampler_epoch
"""
# For now just make sure it runs without errors
# TODO: verify correctness
dataloader = DataLoader(PaddleDataset()) | |||
self.driver.set_sampler_epoch(dataloader, 0) | |||
def test_get_dataloader_args(self): | |||
""" | |||
Test get_dataloader_args
"""
# For now just make sure it runs without errors
# TODO: verify correctness
dataloader = DataLoader(PaddleDataset()) | |||
res = PaddleDriver.get_dataloader_args(dataloader) |
@@ -1,19 +1,19 @@ | |||
import os | |||
os.environ["FASTNLP_BACKEND"] = "paddle" | |||
import pytest | |||
from pathlib import Path | |||
from fastNLP.envs.set_backend import set_env | |||
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle | |||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | |||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler | |||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | |||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||
from fastNLP.core import synchronize_safe_rm | |||
set_env_on_import_paddle() | |||
set_env("paddle") | |||
import paddle | |||
from paddle.io import DataLoader, BatchSampler | |||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | |||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler | |||
from fastNLP.core.samplers import RandomBatchSampler | |||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | |||
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST, PaddleRandomDataset | |||
from fastNLP.core import synchronize_safe_rm | |||
import torch | |||
############################################################################ | |||
@@ -26,38 +26,116 @@ def generate_random_driver(features, labels): | |||
""" | |||
Generate a driver
""" | |||
model = PaddleNormalModel_Classification(labels, features) | |||
model = PaddleNormalModel_Classification_1(labels, features) | |||
opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01) | |||
driver = PaddleSingleDriver(model) | |||
driver = PaddleSingleDriver(model, device="cpu") | |||
driver.set_optimizers(opt) | |||
driver.setup() | |||
return driver | |||
@pytest.fixture | |||
def prepare_test_save_load(): | |||
dataset = PaddleRandomDataset(num_of_data=320, features=64, labels=8) | |||
dataset = PaddleRandomMaxDataset(320, 10) | |||
dataloader = DataLoader(dataset, batch_size=32) | |||
driver1, driver2 = generate_random_driver(64, 8), generate_random_driver(64, 8) | |||
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | |||
return driver1, driver2, dataloader | |||
def test_save_and_load(prepare_test_save_load): | |||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||
def test_save_and_load_with_randombatchsampler(only_state_dict): | |||
""" | |||
Test the save and load functions
TODO the optimizer's state_dict is empty, skip testing it for now
Test save and load, mainly for the case where the dataloader's batch_sampler has been replaced
""" | |||
try: | |||
path = "model.pdparams" | |||
driver1, driver2, dataloader = prepare_test_save_load | |||
path = "model.ckp" | |||
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | |||
dataset = PaddleRandomMaxDataset(80, 10) | |||
dataloader = DataLoader( | |||
dataset=dataset, | |||
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) | |||
) | |||
# TODO iterate a few batches here once checkpoint-resume support is complete
sampler_states = dataloader.batch_sampler.state_dict() | |||
if only_state_dict: | |||
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True) | |||
else: | |||
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||
states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||
# 1. Check the optimizer's state
# TODO the optimizer's state_dict is always empty
# 2. Check that the batch_sampler is correctly loaded and replaced
replaced_loader = states["dataloader"] | |||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] | |||
assert replaced_loader.batch_sampler.data_idx == sampler_states["data_idx"] | |||
# 3. Check that the model parameters are correctly loaded
for batch in dataloader: | |||
res1 = driver1.validate_step(batch) | |||
res2 = driver2.validate_step(batch) | |||
driver1.save(path, {}) | |||
driver2.load(path) | |||
assert paddle.equal_all(res1["pred"], res2["pred"]) | |||
# 4. Check batch_idx
# TODO | |||
finally: | |||
synchronize_safe_rm(path) | |||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||
def test_save_and_load_with_randomsampler(only_state_dict): | |||
""" | |||
Test save and load, mainly for the case where the dataloader's sampler has been replaced
""" | |||
try: | |||
path = "model.ckp" | |||
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | |||
dataset = PaddleRandomMaxDataset(80, 10) | |||
batch_sampler = BatchSampler(dataset=dataset, batch_size=2) | |||
batch_sampler.sampler = RandomSampler(dataset, True) | |||
dataloader = DataLoader( | |||
dataset, | |||
batch_sampler=batch_sampler | |||
) | |||
# TODO iterate a few batches here once checkpoint-resume support is complete
sampler_states = dataloader.batch_sampler.sampler.state_dict() | |||
if only_state_dict: | |||
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True) | |||
else: | |||
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||
states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||
# 1. Check the optimizer's state
# TODO the optimizer's state_dict is always empty
# 2. Check that the sampler is correctly loaded and replaced
replaced_loader = states["dataloader"] | |||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] | |||
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] | |||
assert replaced_loader.batch_sampler.sampler.num_consumed_samples == sampler_states["num_consumed_samples"] | |||
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] | |||
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | |||
# 3. Check that the model parameters are correctly loaded
for batch in dataloader: | |||
res1 = driver1.validate_step(batch) | |||
res2 = driver2.validate_step(batch) | |||
assert paddle.equal_all(res1["pred"], res2["pred"]) | |||
# 4. Check batch_idx
# TODO | |||
finally: | |||
synchronize_safe_rm(path) | |||
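Both tests above exercise the same contract: a reproducible sampler reports its progress through state_dict, and a fresh sampler restored from that state yields only the remainder. A toy, framework-free illustration of the round trip (ToySampler is hypothetical, not fastNLP code):

class ToySampler:
    def __init__(self, n):
        self.n, self.consumed = n, 0
    def __iter__(self):
        while self.consumed < self.n:
            yield self.consumed
            self.consumed += 1
    def state_dict(self):
        return {"num_consumed_samples": self.consumed}
    def load_state_dict(self, states):
        self.consumed = states["num_consumed_samples"]

s1 = ToySampler(10)
it = iter(s1)
first = [next(it) for _ in range(4)]   # training consumed 4 samples
s2 = ToySampler(10)
s2.load_state_dict(s1.state_dict())    # restore into a fresh sampler
rest = list(s2)
assert sorted(first + rest) == list(range(10))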
@@ -67,13 +145,14 @@ def test_save_and_load_state_dict(prepare_test_save_load): | |||
TODO the optimizer's state_dict is empty, skip testing it for now
""" | |||
try: | |||
path = "model.pdparams" | |||
path = "dict" | |||
driver1, driver2, dataloader = prepare_test_save_load | |||
driver1.save_model(path) | |||
driver2.model.load_dict(driver2.load_model(path)) | |||
driver2.load_model(path) | |||
for batch in dataloader: | |||
batch = driver1.move_data_to_device(batch) | |||
res1 = driver1.validate_step(batch) | |||
res2 = driver2.validate_step(batch) | |||
@@ -87,19 +166,22 @@ def test_save_and_load_whole_model(prepare_test_save_load): | |||
TODO the optimizer's state_dict is empty, skip testing it for now
""" | |||
try: | |||
path = "model.pdparams" | |||
path = "model" | |||
driver1, driver2, dataloader = prepare_test_save_load | |||
driver1.save_model(path, only_state_dict=False, input_spec=[next(iter(dataloader))["x"]]) | |||
driver2.model = driver2.load_model(path, load_dict=False) | |||
driver1.save_model(path, only_state_dict=False, input_spec=[paddle.ones((32, 10))]) | |||
driver2.load_model(path, only_state_dict=False) | |||
for batch in dataloader: | |||
batch = driver1.move_data_to_device(batch) | |||
res1 = driver1.validate_step(batch) | |||
res2 = driver2.validate_step(batch) | |||
assert paddle.equal_all(res1["pred"], res2["pred"]) | |||
finally: | |||
synchronize_safe_rm(path) | |||
synchronize_safe_rm(path + ".pdiparams") | |||
synchronize_safe_rm(path + ".pdiparams.info") | |||
synchronize_safe_rm(path + ".pdmodel") | |||
class TestSingleDeviceFunction: | |||
@@ -109,8 +191,8 @@ class TestSingleDeviceFunction: | |||
@classmethod | |||
def setup_class(cls): | |||
model = PaddleNormalModel_Classification(10, 784) | |||
cls.driver = PaddleSingleDriver(model) | |||
model = PaddleNormalModel_Classification_1(10, 784) | |||
cls.driver = PaddleSingleDriver(model, device="cpu") | |||
def test_unwrap_model(self): | |||
""" | |||
@@ -125,22 +207,6 @@ class TestSingleDeviceFunction: | |||
self.driver.check_evaluator_mode("validate") | |||
self.driver.check_evaluator_mode("test") | |||
def test_get_model_device_cpu(self): | |||
""" | |||
Test get_model_device
""" | |||
self.driver = PaddleSingleDriver(PaddleNormalModel_Classification(10, 784), "cpu") | |||
device = self.driver.get_model_device() | |||
assert device == "cpu", device | |||
def test_get_model_device_gpu(self): | |||
""" | |||
Test get_model_device
""" | |||
self.driver = PaddleSingleDriver(PaddleNormalModel_Classification(10, 784), "gpu:0") | |||
device = self.driver.get_model_device() | |||
assert device == "gpu:0", device | |||
def test_is_distributed(self): | |||
assert self.driver.is_distributed() == False | |||
@@ -151,18 +217,420 @@ class TestSingleDeviceFunction: | |||
""" | |||
self.driver.move_data_to_device(paddle.rand((32, 64))) | |||
@pytest.mark.parametrize( | |||
"dist_sampler", | |||
["dist", RandomBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))] | |||
) | |||
@pytest.mark.parametrize( | |||
"reproducible", | |||
[True, False] | |||
) | |||
def test_repalce_sampler(self, dist_sampler, reproducible): | |||
class TestSetDistReproDataloder: | |||
""" | |||
Class dedicated to testing the set_dist_repro_dataloader function
""" | |||
def setup_method(self): | |||
self.dataset = PaddleNormalDataset(20) | |||
self.dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) | |||
model = PaddleNormalModel_Classification_1(10, 32) | |||
self.driver = PaddleSingleDriver(model, device="cpu") | |||
def test_set_dist_repro_dataloader_with_reproducible_false(self): | |||
""" | |||
Test the behavior of set_dist_repro_dataloader when `reproducible` is False
When dist is a string, the original dataloader should be returned
"""
Test the replace_sampler function
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=False) | |||
assert replaced_loader is self.dataloader | |||
def test_set_dist_repro_dataloader_with_reproducible_true(self): | |||
""" | |||
Test the behavior of set_dist_repro_dataloader when `reproducible` is True
When dist is a string, a new dataloader should be returned, with a RandomBatchSampler as its batch_sampler
""" | |||
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=True) | |||
assert not (replaced_loader is self.dataloader) | |||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | |||
assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size | |||
assert replaced_loader.drop_last == self.dataloader.drop_last | |||
# self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) | |||
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self): | |||
""" | |||
dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True) | |||
Test the behavior of set_dist_repro_dataloader when dist is not a string but a ReproducibleBatchSampler
A new dataloader should be returned, with its batch_sampler replaced by the sampler passed as dist
""" | |||
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False) | |||
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False) | |||
assert not (replaced_loader is self.dataloader) | |||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||
assert replaced_loader.batch_sampler is dist | |||
# self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) | |||
def test_set_dist_repro_dataloader_with_dist_sampler(self): | |||
""" | |||
Test the behavior of set_dist_repro_dataloader when dist is not a string
A new dataloader should be returned, with batch_sampler.sampler replaced by the sampler passed as dist
""" | |||
dist = RandomSampler(self.dataset, shuffle=True) | |||
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False) | |||
assert not (replaced_loader is self.dataloader) | |||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||
assert not (replaced_loader.batch_sampler is self.dataloader.batch_sampler) | |||
assert replaced_loader.batch_sampler.sampler is dist | |||
assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size | |||
# self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) | |||
def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self): | |||
""" | |||
Test the behavior of set_dist_repro_dataloader when the dataloader already supports checkpoint-resume
A new dataloader should be returned, with all other settings identical to the original
""" | |||
dataloader = DataLoader( | |||
dataset=self.dataset, | |||
batch_sampler=RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False) | |||
) | |||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||
assert not (replaced_loader is dataloader) | |||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||
assert replaced_loader.drop_last == dataloader.drop_last | |||
res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible) | |||
# self.check_set_dist_repro_dataloader(dataloader, replaced_loader) | |||
def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self): | |||
""" | |||
Test the behavior of set_dist_repro_dataloader when the dataloader already supports checkpoint-resume
A new dataloader should be returned, with all other settings identical to the original
""" | |||
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) | |||
batch_sampler.sampler = RandomSampler(self.dataset, True) | |||
dataloader = DataLoader( | |||
self.dataset, | |||
batch_sampler=batch_sampler | |||
) | |||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||
assert not (replaced_loader is dataloader) | |||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||
assert replaced_loader.batch_sampler.batch_size == 2 | |||
# self.check_set_dist_repro_dataloader(dataloader, replaced_loader) | |||
def check_set_dist_repro_dataloader(self, dataloader, replaced_loader): | |||
""" | |||
Check that set_dist_repro_dataloader behaves correctly on a single device
"""
# Iterate over two batches
# Here the BatchSampler may yield several times while the dataloader fetches only once.
already_seen_idx = set() | |||
for idx, batch in replaced_loader: | |||
already_seen_idx.update(batch) | |||
if idx >= 1: | |||
break | |||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | |||
sampler_states = replaced_loader.batch_sampler.state_dict() | |||
else: | |||
sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | |||
print(sampler_states["data_idx"]) | |||
# After reloading, the remaining items should be produced; for PaddleNormalDataset the indices, once sorted, should form a range
left_idxes = set() | |||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | |||
replaced_loader.batch_sampler.load_state_dict(sampler_states) | |||
else: | |||
replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states) | |||
for idx, batch in enumerate(replaced_loader): | |||
left_idxes.update(batch) | |||
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) | |||
assert len(left_idxes | already_seen_idx) == len(self.dataset) | |||
class TestPaddleDriverFunctions: | |||
""" | |||
Test the base-class functions via PaddleSingleDriver
""" | |||
@classmethod | |||
def setup_class(self): | |||
model = PaddleNormalModel_Classification_1(10, 32) | |||
self.driver = PaddleSingleDriver(model, device="cpu") | |||
def test_check_single_optimizer_legality(self): | |||
""" | |||
Test the behavior when a single optimizer is passed in
""" | |||
optimizer = paddle.optimizer.Adam( | |||
parameters=self.driver.model.parameters(), | |||
learning_rate=0.01 | |||
) | |||
self.driver.set_optimizers(optimizer) | |||
optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||
# Passing a torch optimizer should raise ValueError
with pytest.raises(ValueError): | |||
self.driver.set_optimizers(optimizer) | |||
def test_check_optimizers_legality(self): | |||
""" | |||
Test the behavior when a list of optimizers is passed in
""" | |||
optimizers = [ | |||
paddle.optimizer.Adam( | |||
parameters=self.driver.model.parameters(), | |||
learning_rate=0.01 | |||
) for i in range(10) | |||
] | |||
self.driver.set_optimizers(optimizers) | |||
optimizers += [ | |||
torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||
] | |||
with pytest.raises(ValueError): | |||
self.driver.set_optimizers(optimizers) | |||
def test_check_dataloader_legality_in_train(self): | |||
""" | |||
Test _check_dataloader_legality when the is_train parameter is True
""" | |||
dataloader = paddle.io.DataLoader(PaddleNormalDataset()) | |||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||
# Case where both batch_size and batch_sampler are None
dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||
# Create a torch dataloader
dataloader = torch.utils.data.DataLoader( | |||
TorchNormalDataset(), | |||
batch_size=32, shuffle=True | |||
) | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||
def test_check_dataloader_legality_in_test(self): | |||
""" | |||
Test _check_dataloader_legality when the is_train parameter is False
"""
# In this case the argument should be a dict
dataloader = { | |||
"train": paddle.io.DataLoader(PaddleNormalDataset()), | |||
"test":paddle.io.DataLoader(PaddleNormalDataset()) | |||
} | |||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||
# Case where both batch_size and batch_sampler are None
dataloader = { | |||
"train": paddle.io.DataLoader(PaddleNormalDataset()), | |||
"test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) | |||
} | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||
# Not a dict: should raise an error
dataloader = paddle.io.DataLoader(PaddleNormalDataset()) | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||
# Create a torch dataloader
train_loader = torch.utils.data.DataLoader( | |||
TorchNormalDataset(), | |||
batch_size=32, shuffle=True | |||
) | |||
test_loader = torch.utils.data.DataLoader( | |||
TorchNormalDataset(), | |||
batch_size=32, shuffle=True | |||
) | |||
dataloader = {"train": train_loader, "test": test_loader} | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||
def test_tensor_to_numeric(self): | |||
""" | |||
Test the tensor_to_numeric function
"""
# A single tensor
tensor = paddle.to_tensor(3) | |||
res = PaddleSingleDriver.tensor_to_numeric(tensor) | |||
assert res == 3 | |||
tensor = paddle.rand((3, 4)) | |||
res = PaddleSingleDriver.tensor_to_numeric(tensor) | |||
assert res == tensor.tolist() | |||
# A list of tensors
tensor_list = [paddle.rand((6, 4, 2)) for i in range(10)] | |||
res = PaddleSingleDriver.tensor_to_numeric(tensor_list) | |||
assert isinstance(res, list) | |||
tensor_list = [t.tolist() for t in tensor_list] | |||
assert res == tensor_list | |||
# A tuple of tensors
tensor_tuple = tuple([paddle.rand((6, 4, 2)) for i in range(10)]) | |||
res = PaddleSingleDriver.tensor_to_numeric(tensor_tuple) | |||
assert isinstance(res, tuple) | |||
tensor_tuple = tuple([t.tolist() for t in tensor_tuple]) | |||
assert res == tensor_tuple | |||
# A dict of tensors
tensor_dict = { | |||
"tensor": paddle.rand((3, 4)), | |||
"list": [paddle.rand((6, 4, 2)) for i in range(10)], | |||
"dict":{ | |||
"list": [paddle.rand((6, 4, 2)) for i in range(10)], | |||
"tensor": paddle.rand((3, 4)) | |||
}, | |||
"int": 2, | |||
"string": "test string" | |||
} | |||
res = PaddleSingleDriver.tensor_to_numeric(tensor_dict) | |||
assert isinstance(res, dict) | |||
assert res["tensor"] == tensor_dict["tensor"].tolist() | |||
assert isinstance(res["list"], list) | |||
for r, d in zip(res["list"], tensor_dict["list"]): | |||
assert r == d.tolist() | |||
assert isinstance(res["int"], int) | |||
assert isinstance(res["string"], str) | |||
assert isinstance(res["dict"], dict) | |||
assert isinstance(res["dict"]["list"], list) | |||
for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]): | |||
assert r == d.tolist() | |||
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() | |||
def test_set_model_mode(self): | |||
""" | |||
Test the set_model_mode function
""" | |||
self.driver.set_model_mode("train") | |||
assert self.driver.model.training | |||
self.driver.set_model_mode("eval") | |||
assert not self.driver.model.training | |||
# Should raise an error
with pytest.raises(AssertionError): | |||
self.driver.set_model_mode("test") | |||
def test_move_model_to_device_cpu(self): | |||
""" | |||
Test the move_model_to_device function
""" | |||
PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu") | |||
assert self.driver.model.linear1.weight.place.is_cpu_place() | |||
def test_move_model_to_device_gpu(self): | |||
""" | |||
Test the move_model_to_device function
""" | |||
PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu") | |||
assert self.driver.model.linear1.weight.place.is_gpu_place() | |||
assert self.driver.model.linear1.weight.place.gpu_device_id() == 0 | |||
def test_worker_init_function(self): | |||
""" | |||
Test worker_init_function
"""
# For now just make sure it runs without errors
# TODO: verify correctness
PaddleSingleDriver.worker_init_function(0) | |||
def test_set_deterministic_dataloader(self): | |||
""" | |||
Test set_deterministic_dataloader
"""
# For now just make sure it runs without errors
# TODO: verify correctness
dataloader = DataLoader(PaddleNormalDataset()) | |||
self.driver.set_deterministic_dataloader(dataloader) | |||
def test_set_sampler_epoch(self): | |||
""" | |||
Test set_sampler_epoch
"""
# For now just make sure it runs without errors
# TODO: verify correctness
dataloader = DataLoader(PaddleNormalDataset()) | |||
self.driver.set_sampler_epoch(dataloader, 0) | |||
@pytest.mark.parametrize("batch_size", [16]) | |||
@pytest.mark.parametrize("shuffle", [True, False]) | |||
@pytest.mark.parametrize("drop_last", [True, False]) | |||
def test_get_dataloader_args(self, batch_size, shuffle, drop_last): | |||
""" | |||
Test get_dataloader_args in the normal case
""" | |||
dataloader = DataLoader( | |||
PaddleNormalDataset(), | |||
batch_size=batch_size, | |||
shuffle=shuffle, | |||
drop_last=drop_last, | |||
) | |||
res = PaddleSingleDriver.get_dataloader_args(dataloader) | |||
assert isinstance(res.dataset, PaddleNormalDataset) | |||
assert isinstance(res.batch_sampler, BatchSampler) | |||
if shuffle: | |||
assert isinstance(res.sampler, paddle.io.RandomSampler) | |||
else: | |||
assert isinstance(res.sampler, paddle.io.SequenceSampler) | |||
assert res.shuffle == shuffle | |||
assert res.batch_size == batch_size | |||
assert res.drop_last == drop_last | |||
@pytest.mark.parametrize("batch_size", [16]) | |||
@pytest.mark.parametrize("shuffle", [True, False]) | |||
@pytest.mark.parametrize("drop_last", [True, False]) | |||
def test_get_dataloader_args_with_randombatchsampler(self, batch_size, shuffle, drop_last): | |||
""" | |||
Test get_dataloader_args after the batch_sampler has been replaced
""" | |||
dataset = PaddleNormalDataset() | |||
dataloader = DataLoader( | |||
dataset, | |||
batch_sampler=RandomBatchSampler( | |||
BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle), | |||
batch_size, | |||
drop_last, | |||
) | |||
) | |||
res = PaddleSingleDriver.get_dataloader_args(dataloader) | |||
assert isinstance(res.dataset, PaddleNormalDataset) | |||
assert isinstance(res.batch_sampler, RandomBatchSampler) | |||
if shuffle: | |||
assert isinstance(res.sampler, paddle.io.RandomSampler) | |||
else: | |||
assert isinstance(res.sampler, paddle.io.SequenceSampler) | |||
assert res.shuffle == shuffle | |||
assert res.batch_size == batch_size | |||
assert res.drop_last == drop_last | |||
@pytest.mark.parametrize("batch_size", [16]) | |||
@pytest.mark.parametrize("shuffle", [True, False]) | |||
@pytest.mark.parametrize("drop_last", [True, False]) | |||
def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last): | |||
""" | |||
Test get_dataloader_args after the sampler has been replaced
""" | |||
dataset = PaddleNormalDataset() | |||
batch_sampler = BatchSampler(dataset, batch_size=batch_size, drop_last=drop_last) | |||
batch_sampler.sampler = RandomSampler(dataset, shuffle) | |||
dataloader = DataLoader( | |||
dataset, | |||
batch_sampler=batch_sampler, | |||
) | |||
res = PaddleSingleDriver.get_dataloader_args(dataloader) | |||
assert isinstance(res.dataset, PaddleNormalDataset) | |||
assert isinstance(res.batch_sampler, BatchSampler) | |||
assert isinstance(res.sampler, RandomSampler) | |||
assert res.shuffle == shuffle | |||
assert res.batch_size == batch_size | |||
assert res.drop_last == drop_last |
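The three get_dataloader_args tests above all destructure the same record; per their assertions it carries at least the fields below. A sketch of the shape only (assuming a plain dataclass; the real return type may differ):

from dataclasses import dataclass
from typing import Any

@dataclass
class DataLoaderArgs:
    dataset: Any
    batch_sampler: Any
    sampler: Any
    shuffle: bool
    batch_size: int
    drop_last: bool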
@@ -1,4 +1,56 @@ | |||
import unittest | |||
import os | |||
import pytest | |||
os.environ["FASTNLP_BACKEND"] = "paddle" | |||
from fastNLP.core.drivers.paddle_driver.utils import ( | |||
get_device_from_visible, | |||
replace_batch_sampler, | |||
replace_sampler, | |||
) | |||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler | |||
import paddle | |||
from paddle.io import Dataset, DataLoader, DistributedBatchSampler | |||
from paddle.io import DataLoader, BatchSampler | |||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset | |||
@pytest.mark.parametrize( | |||
("user_visible_devices, cuda_visible_devices, device, output_type, correct"), | |||
( | |||
("0,1,2,3,4,5,6,7", "0", "cpu", str, "cpu"), | |||
("0,1,2,3,4,5,6,7", "0", "cpu", int, "cpu"), | |||
("0,1,2,3,4,5,6,7", "3,4,5", "gpu:4", int, 1), | |||
("0,1,2,3,4,5,6,7", "3,4,5", "gpu:5", str, "gpu:2"), | |||
("3,4,5,6", "3,5", 0, int, 0), | |||
("3,6,7,8", "6,7,8", "gpu:2", str, "gpu:1"), | |||
) | |||
) | |||
def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices, device, output_type, correct): | |||
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices | |||
os.environ["USER_CUDA_VISIBLE_DEVICES"] = user_visible_devices | |||
res = get_device_from_visible(device, output_type) | |||
assert res == correct | |||
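The parametrized cases above encode a re-indexing: device indexes into USER_CUDA_VISIBLE_DEVICES, and the result is that GPU's position within CUDA_VISIBLE_DEVICES. A sketch of the arithmetic the cases imply (an assumption drawn from the table, not the actual implementation):

import os

def device_from_visible_sketch(device, output_type=int):
    if device == "cpu":
        return device
    user = os.environ["USER_CUDA_VISIBLE_DEVICES"].split(",")
    cuda = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
    idx = int(str(device).replace("gpu:", ""))
    local = cuda.index(user[idx])
    return local if output_type is int else f"gpu:{local}"

os.environ["USER_CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
os.environ["CUDA_VISIBLE_DEVICES"] = "3,4,5"
assert device_from_visible_sketch("gpu:4") == 1             # third case above
assert device_from_visible_sketch("gpu:5", str) == "gpu:2"  # fourth case above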
def test_replace_batch_sampler(): | |||
dataset = PaddleNormalDataset(10) | |||
dataloader = DataLoader(dataset, batch_size=32) | |||
batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) | |||
replaced_loader = replace_batch_sampler(dataloader, batch_sampler) | |||
assert not (replaced_loader is dataloader) | |||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||
assert isinstance(replaced_loader.dataset, PaddleNormalDataset) | |||
assert len(replaced_loader.dataset) == len(dataset) | |||
assert replaced_loader.batch_sampler.batch_size == 16 | |||
def test_replace_sampler(): | |||
dataset = PaddleNormalDataset(10) | |||
dataloader = DataLoader(dataset, batch_size=32) | |||
sampler = RandomSampler(dataset) | |||
replaced_loader = replace_sampler(dataloader, sampler) | |||
assert not (replaced_loader is dataloader) | |||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) |
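Both helpers above must return a new DataLoader that keeps the dataset but swaps the sampling machinery, since paddle's DataLoader fixes its batching at construction time. A minimal sketch of what replace_batch_sampler has to do — an assumption; the real helper presumably copies the remaining constructor arguments as well, and the attribute names used here may differ:

from paddle.io import DataLoader

def replace_batch_sampler_sketch(dataloader, new_batch_sampler):
    # Rebuild rather than mutate the existing loader.
    return DataLoader(
        dataset=dataloader.dataset,
        batch_sampler=new_batch_sampler,
        num_workers=dataloader.num_workers,
    )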
@@ -6,13 +6,16 @@ import logging | |||
import re | |||
from fastNLP.envs.env import FASTNLP_LAUNCH_TIME | |||
from tests.helpers.utils import magic_argv_env_context | |||
from fastNLP.core import synchronize_safe_rm | |||
from fastNLP.core.log.logger import logger | |||
from tests.helpers.utils import magic_argv_env_context, recover_logger | |||
# Test TorchDDPDriver;
@magic_argv_env_context | |||
def test_add_file_ddp_1(): | |||
@recover_logger | |||
def test_add_file_ddp_1_torch(): | |||
""" | |||
Test the case where path is a file path whose parent folder exists;
@@ -56,11 +59,11 @@ def test_add_file_ddp_1(): | |||
synchronize_safe_rm(filepath) | |||
dist.barrier() | |||
dist.destroy_process_group() | |||
logger.removeHandler(handler) | |||
@magic_argv_env_context | |||
def test_add_file_ddp_2(): | |||
@recover_logger | |||
def test_add_file_ddp_2_torch(): | |||
""" | |||
Test the case where path is a file path whose parent folder does not exist;
""" | |||
@@ -103,14 +106,14 @@ def test_add_file_ddp_2(): | |||
assert len(pattern.findall(line)) == 1 | |||
finally: | |||
synchronize_safe_rm(path) | |||
logger.removeHandler(handler) | |||
dist.barrier() | |||
dist.destroy_process_group() | |||
@magic_argv_env_context | |||
def test_add_file_ddp_3(): | |||
@recover_logger | |||
def test_add_file_ddp_3_torch(): | |||
""" | |||
path = None; | |||
@@ -155,10 +158,10 @@ def test_add_file_ddp_3(): | |||
synchronize_safe_rm(file) | |||
dist.barrier() | |||
dist.destroy_process_group() | |||
logger.removeHandler(handler) | |||
@magic_argv_env_context | |||
def test_add_file_ddp_4(): | |||
@recover_logger | |||
def test_add_file_ddp_4_torch(): | |||
""" | |||
测试 path 是文件夹; | |||
""" | |||
@@ -200,7 +203,6 @@ def test_add_file_ddp_4(): | |||
assert len(pattern.findall(line)) == 1 | |||
finally: | |||
synchronize_safe_rm(path) | |||
logger.removeHandler(handler) | |||
dist.barrier() | |||
dist.destroy_process_group() | |||
@@ -209,12 +211,11 @@ def test_add_file_ddp_4(): | |||
class TestLogger: | |||
msg = 'some test log msg' | |||
@recover_logger | |||
def test_add_file_1(self): | |||
""" | |||
Test the case where path is a file path whose parent folder exists;
""" | |||
from fastNLP.core.log.logger import logger | |||
path = Path(tempfile.mkdtemp()) | |||
try: | |||
filepath = path.joinpath('log.txt') | |||
@@ -225,14 +226,12 @@ class TestLogger: | |||
assert self.msg in line | |||
finally: | |||
synchronize_safe_rm(path) | |||
logger.removeHandler(handler) | |||
@recover_logger | |||
def test_add_file_2(self): | |||
""" | |||
Test the case where path is a file path whose parent folder does not exist;
""" | |||
from fastNLP.core.log.logger import logger | |||
origin_path = Path(tempfile.mkdtemp()) | |||
try: | |||
@@ -245,14 +244,12 @@ class TestLogger: | |||
assert self.msg in line | |||
finally: | |||
synchronize_safe_rm(origin_path) | |||
logger.removeHandler(handler) | |||
@recover_logger | |||
def test_add_file_3(self): | |||
""" | |||
Test the case where path is None;
""" | |||
from fastNLP.core.log.logger import logger | |||
handler = logger.add_file() | |||
logger.info(self.msg) | |||
@@ -264,14 +261,12 @@ class TestLogger: | |||
line = ''.join([l for l in f]) | |||
assert self.msg in line | |||
file.unlink() | |||
logger.removeHandler(handler) | |||
@recover_logger | |||
def test_add_file_4(self): | |||
""" | |||
Test the case where path is a folder;
""" | |||
from fastNLP.core.log.logger import logger | |||
path = Path(tempfile.mkdtemp()) | |||
try: | |||
handler = logger.add_file(path) | |||
@@ -285,16 +280,21 @@ class TestLogger: | |||
assert self.msg in line | |||
finally: | |||
synchronize_safe_rm(path) | |||
logger.removeHandler(handler) | |||
@recover_logger | |||
def test_stdout(self, capsys): | |||
from fastNLP.core.log.logger import logger | |||
handler = logger.set_stdout(stdout="raw") | |||
logger.info(self.msg) | |||
logger.debug('aabbc') | |||
captured = capsys.readouterr() | |||
assert "some test log msg\n" == captured.out | |||
logger.removeHandler(handler) | |||
@recover_logger | |||
def test_warning_once(self, capsys): | |||
logger.warning_once('#') | |||
logger.warning_once('#') | |||
logger.warning_once('@') | |||
captured = capsys.readouterr() | |||
assert captured.out.count('#') == 1 | |||
assert captured.out.count('@') == 1 | |||
@@ -3,6 +3,7 @@ from array import array | |||
import numpy as np | |||
import pytest | |||
from itertools import chain | |||
from copy import deepcopy | |||
from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler | |||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | |||
@@ -30,7 +31,7 @@ class TestReproducibleBatchSampler: | |||
_get_re_batchsampler = dataloader.batch_sampler | |||
assert isinstance(_get_re_batchsampler, RandomBatchSampler) | |||
state = _get_re_batchsampler.state_dict() | |||
assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, | |||
assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size, | |||
"sampler_type": "RandomBatchSampler"} | |||
# 2. Resume from checkpoint and regenerate a dataloader;
@@ -413,26 +414,102 @@ class TestBucketedBatchSampler: | |||
@pytest.mark.parametrize('drop_last', [True, False]) | |||
@pytest.mark.parametrize('pad', [True, False]) | |||
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | |||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replica): | |||
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replica=2): | |||
@pytest.mark.parametrize('num_replicas', [2, 3]) | |||
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replicas): | |||
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replicas=2): | |||
dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||
batch_size = 6 | |||
if num_replica*batch_size > num_samples: | |||
if num_replicas*batch_size > num_samples: | |||
return | |||
num_batch_per_bucket = 10 | |||
samplers = [] | |||
lengths = [] | |||
for i in range(num_replica): | |||
for i in range(num_replicas): | |||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size, | |||
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last) | |||
sampler.set_distributed(num_replica, rank=i, pad=pad) | |||
sampler.set_distributed(num_replicas, rank=i, pad=pad) | |||
sampler.set_epoch(0) | |||
samplers.append(sampler) | |||
lengths.append(len(list(iter(sampler)))) | |||
assert len(set(lengths))==1 | |||
bucket_diff = batch_size * num_batch_per_bucket * num_replica | |||
bucket_diff = batch_size * num_batch_per_bucket * num_replicas | |||
for bs in zip(*samplers): | |||
diff = max(chain(*bs)) - min(chain(*bs)) | |||
assert diff <= bucket_diff | |||
@pytest.mark.parametrize('shuffle', [True, False]) | |||
@pytest.mark.parametrize('drop_last', [True, False]) | |||
@pytest.mark.parametrize('pad', [True, False]) | |||
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | |||
@pytest.mark.parametrize('num_replicas', [1, 2, 3]) | |||
def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas): | |||
""" | |||
Test whether already-consumed (forwarded) data can be restored correctly; because the DataLoader
prefetches, the Sampler's own num_consumed_samples may overshoot
:return:
""" | |||
batch_size = 6 | |||
num_batch_per_bucket = 10 | |||
dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||
samplers = [] | |||
for i in range(num_replicas): | |||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size, | |||
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last) | |||
sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad) | |||
samplers.append(sampler) | |||
count = 0 | |||
already_seen_sets = [set()] | |||
already_seen_set = set() | |||
for batches in zip(*samplers):
batch = chain(*batches)
already_seen_set.update(batch) | |||
already_seen_sets.append(deepcopy(already_seen_set)) | |||
count += 1 | |||
if count > 3: | |||
break | |||
states = samplers[0].state_dict() | |||
for i in range(len(already_seen_sets)): | |||
if states['num_consumed_samples_array'] is not None: | |||
states['num_consumed_samples'] = states['num_consumed_samples_array'][i] | |||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1, | |||
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | |||
drop_last=drop_last) | |||
sampler.load_state_dict(states)
sampler.set_epoch(0)
already_seen_set = deepcopy(already_seen_sets[i]) | |||
for batch in sampler: | |||
already_seen_set.update(batch) | |||
assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len(dataset)
# test that a checkpoint taken after a resume can itself be saved and restored
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1,
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
drop_last=drop_last)
if len(already_seen_sets) < 3:
return
if states['num_consumed_samples_array'] is not None:
states['num_consumed_samples'] = states['num_consumed_samples_array'][2]
sampler.load_state_dict(states)
sampler.set_epoch(0)
already_seen_set = already_seen_sets[2]
count = 0 | |||
for batch in sampler: | |||
already_seen_set.update(batch) | |||
count += 1 | |||
if count > 6: | |||
break | |||
states = sampler.state_dict() | |||
if states['num_consumed_samples_array'] is not None: | |||
states['num_consumed_samples'] = states['num_consumed_samples_array'][count] | |||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2, | |||
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | |||
drop_last=drop_last) | |||
sampler.load_state_dict(states) | |||
sampler.set_epoch(0) | |||
for batch in sampler: | |||
already_seen_set.update(batch) | |||
assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset) |
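The resume flow buried in the parametrization above is easier to read in isolation. The following is a condensed, illustrative sketch, assuming only the BucketedBatchSampler API these tests exercise (state_dict()/load_state_dict()/set_epoch(), and num_consumed_samples_array indexed by the number of batches handed out); the concrete numbers are hypothetical.

from fastNLP.core.samplers import BucketedBatchSampler

dataset = DatasetWithVaryLength(num_of_data=100)
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=6,
                               num_batch_per_bucket=10, shuffle=True)
sampler.set_epoch(0)

it = iter(sampler)
for _ in range(4):       # the DataLoader may prefetch batches the model never saw
    next(it)

states = sampler.state_dict()
trained_batches = 2      # batches the training loop actually finished
# num_consumed_samples_array maps "batches handed out" to "samples consumed",
# letting the checkpoint rewind past the prefetched-but-untrained batches.
if states['num_consumed_samples_array'] is not None:
    states['num_consumed_samples'] = states['num_consumed_samples_array'][trained_batches]

resumed = BucketedBatchSampler(dataset, length=dataset.data, batch_size=7,  # batch_size may differ
                               num_batch_per_bucket=10, shuffle=True)
resumed.load_state_dict(states)
resumed.set_epoch(0)     # iterating `resumed` now yields only the unseen samples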
@@ -3,6 +3,7 @@ import pytest | |||
from functools import partial | |||
from itertools import chain | |||
from copy import deepcopy | |||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler | |||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||
@@ -180,6 +181,63 @@ class TestRandomSamplerYh: | |||
assert seen <= 1 if pad else seen == 0 | |||
assert seen_in_other_rank<=1 # padding may duplicate a sample across ranks
@pytest.mark.parametrize('shuffle', [True, False]) | |||
@pytest.mark.parametrize('pad', [True, False]) | |||
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | |||
@pytest.mark.parametrize('num_replicas', [1, 2, 3]) | |||
def test_num_consumed_samples_array(self, shuffle, pad, num_samples, num_replicas): | |||
# test that recovery still works when the sampler has generated more indices than were actually consumed
dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||
samplers = [] | |||
for i in range(num_replicas): | |||
sampler = RandomSampler(dataset, shuffle=shuffle) | |||
sampler.set_epoch(0) | |||
sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad) | |||
samplers.append(sampler) | |||
count = 0 | |||
already_seen_sets = [set()] | |||
already_seen_set = set() | |||
for idxes in zip(*samplers): | |||
already_seen_set.update(idxes) | |||
already_seen_sets.append(deepcopy(already_seen_set)) | |||
count += 1 | |||
if count > 3: | |||
break | |||
states = samplers[0].state_dict() | |||
for i in range(len(already_seen_sets)): | |||
if states['num_consumed_samples_array'] is not None: | |||
states['num_consumed_samples'] = states['num_consumed_samples_array'][i] | |||
sampler = RandomSampler(dataset, shuffle=shuffle)
sampler.load_state_dict(states)
sampler.set_epoch(0)
already_seen_set = deepcopy(already_seen_sets[i])
for idx in sampler:
already_seen_set.add(idx)
assert len(already_seen_set) == len(dataset) | |||
# test that a checkpoint taken after a resume can itself be saved and restored
sampler = RandomSampler(dataset, shuffle=shuffle)
if len(already_seen_sets) < 3:
return
if states['num_consumed_samples_array'] is not None:
states['num_consumed_samples'] = states['num_consumed_samples_array'][2]
sampler.load_state_dict(states)
sampler.set_epoch(0)
already_seen_set = already_seen_sets[2]
count = 0 | |||
for idx in sampler: | |||
already_seen_set.add(idx) | |||
count += 1 | |||
if count > 6: | |||
break | |||
states = sampler.state_dict() | |||
if states['num_consumed_samples_array'] is not None: | |||
states['num_consumed_samples'] = states['num_consumed_samples_array'][count] | |||
sampler = RandomSampler(dataset, shuffle=shuffle) | |||
sampler.load_state_dict(states) | |||
sampler.set_epoch(0) | |||
for idx in sampler: | |||
already_seen_set.add(idx) | |||
assert len(already_seen_set)==len(dataset) | |||
class TestRandomSampler: | |||
# single-card tests;
@@ -386,7 +444,7 @@ class TestSortedSampler: | |||
assert indexes==list(range(num_of_data-1, -1, -1)) | |||
@pytest.mark.parametrize('pad', [True, False]) | |||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||
@pytest.mark.parametrize('num_replicas', [2, 3]) | |||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
def test_multi(self, pad, num_replicas, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data) | |||
@@ -540,7 +598,7 @@ class TestSequentialSampler: | |||
assert indexes==list(range(num_of_data)) | |||
@pytest.mark.parametrize('pad', [True, False]) | |||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||
@pytest.mark.parametrize('num_replicas', [2, 3]) | |||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
def test_multi(self, pad, num_replicas, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data) | |||
@@ -25,7 +25,7 @@ class TestUnrepeatedSampler: | |||
indexes = set(sampler) | |||
assert indexes==set(range(num_of_data)) | |||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||
@pytest.mark.parametrize('num_replicas', [2, 3]) | |||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
@pytest.mark.parametrize('shuffle', [False, True]) | |||
def test_multi(self, num_replicas, num_of_data, shuffle):
@@ -50,7 +50,7 @@ class TestUnrepeatedSortedSampler: | |||
indexes = list(sampler) | |||
assert indexes==list(range(num_of_data-1, -1, -1)) | |||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||
@pytest.mark.parametrize('num_replicas', [2, 3]) | |||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
def test_multi(self, num_replicas, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data) | |||
@@ -81,7 +81,7 @@ class TestUnrepeatedSequentialSampler: | |||
indexes = list(sampler) | |||
assert indexes==list(range(num_of_data)) | |||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||
@pytest.mark.parametrize('num_replicas', [2, 3]) | |||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
def test_multi(self, num_replicas, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data) | |||
@@ -0,0 +1,187 @@ | |||
from functools import partial | |||
import pytest | |||
from fastNLP.core.utils.utils import auto_param_call, _check_valid_parameters_number, _get_fun_msg | |||
from fastNLP.core.metrics import Metric | |||
class TestAutoParamCall: | |||
def test_basic(self): | |||
def fn(x): | |||
return x | |||
x = {'x': 3, 'y': 4} | |||
r = auto_param_call(fn, x) | |||
assert r==3 | |||
xs = [] | |||
for i in range(10): | |||
xs.append({f'x{i}': i}) | |||
def fn(x0, x1, x2, x3): | |||
return x0 + x1 + x2 + x3 | |||
r = auto_param_call(fn, *xs) | |||
assert r == 0 + 1 + 2 + 3
def fn(chongfu1, chongfu2, buChongFu): | |||
pass | |||
with pytest.raises(BaseException) as exc_info: | |||
auto_param_call(fn, {'chongfu1': 3, "chongfu2":4, 'buChongFu':2}, | |||
{'chongfu1': 1, 'chongfu2':2, 'buChongFu':2}) | |||
assert 'The following key present in several inputs' in exc_info.value.args[0] | |||
assert 'chongfu1' in exc_info.value.args[0] and 'chongfu2' in exc_info.value.args[0] | |||
# duplicate keys that are not used do not raise
def fn(chongfu1, buChongFu): | |||
pass | |||
auto_param_call(fn, {'chongfu1': 1, "chongfu2":4, 'buChongFu':2}, | |||
{'chongfu1': 1, 'chongfu2':2, 'buChongFu':2}) | |||
# a custom signature_fn can be supplied
def fn1(**kwargs): | |||
kwargs.pop('x') | |||
kwargs.pop('y') | |||
assert len(kwargs)==0 | |||
def fn(x, y): | |||
pass | |||
x = {'x': 3, 'y': 4} | |||
r = auto_param_call(fn1, x, signature_fn=fn) | |||
# raises when required arguments are not provided
def fn(meiti1, meiti2, tigong): | |||
pass | |||
with pytest.raises(BaseException) as exc_info: | |||
auto_param_call(fn, {'tigong':1}) | |||
assert 'meiti1' in exc_info.value.args[0] and 'meiti2' in exc_info.value.args[0] | |||
# default values are used and can be overridden
def fn(x, y=100): | |||
return x + y | |||
r = auto_param_call(fn, {'x': 10, 'y': 20}) | |||
assert r==30 | |||
assert auto_param_call(fn, {'x': 10, 'z': 20})==110 | |||
# test the mapping argument
def fn(x, y=100): | |||
return x + y | |||
r = auto_param_call(fn, {'x1': 10, 'y1': 20}, mapping={'x1': 'x', 'y1': 'y', 'meiyong': 'meiyong'}) | |||
assert r==30 | |||
# works for a function that takes no parameters
def fn(): | |||
return 1 | |||
assert 1 == auto_param_call(fn, {'x':1}) | |||
# calling a bound method works fine
assert 2==auto_param_call(self.call_this, {'x':1 ,'y':1}) | |||
assert 2==auto_param_call(self.call_this, {'x':1,'y':1, 'z':1},mapping={'z': 'self'}) | |||
def test_msg(self): | |||
with pytest.raises(BaseException) as exc_info: | |||
auto_param_call(self.call_this, {'x':1}) | |||
assert 'TestAutoParamCall.call_this' in exc_info.value.args[0] | |||
with pytest.raises(BaseException) as exc_info: | |||
auto_param_call(call_this_for_auto_param_call, {'x':1}) | |||
assert __file__ in exc_info.value.args[0] | |||
assert 'call_this_for_auto_param_call' in exc_info.value.args[0] | |||
with pytest.raises(BaseException) as exc_info: | |||
auto_param_call(self.call_this_two, {'x':1}) | |||
assert __file__ in exc_info.value.args[0] | |||
with pytest.raises(BaseException) as exc_info: | |||
auto_param_call(call_this_for_auto_param_call, {'x':1}, signature_fn=self.call_this) | |||
assert 'TestAutoParamCall.call_this' in exc_info.value.args[0] # the message should describe the signature_fn
def call_this(self, x, y): | |||
return x + y | |||
def call_this_two(self, x, y, z=pytest, **kwargs): | |||
return x + y | |||
def test_metric_auto_param_call(self): | |||
metric = AutoParamCallMetric() | |||
with pytest.raises(BaseException): | |||
auto_param_call(metric.update, {'y':1}, signature_fn=metric.update.__wrapped__) | |||
class AutoParamCallMetric(Metric): | |||
def update(self, x): | |||
pass | |||
def call_this_for_auto_param_call(x, y): | |||
return x + y | |||
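For readers skimming these cases: auto_param_call inspects the target's signature (or a separate signature_fn), picks matching keys out of the input dicts, applies the optional mapping, and lets declared defaults fill the rest. A minimal illustrative sketch of that idea, not fastNLP's actual implementation (duplicate-key detection and the rich error messages are omitted):

import inspect

def sketch_auto_param_call(fn, *dicts, signature_fn=None, mapping=None):
    sig = inspect.signature(signature_fn if signature_fn is not None else fn)
    kwargs = {}
    for d in dicts:
        for key, value in d.items():
            key = (mapping or {}).get(key, key)  # apply user-supplied renaming
            if key in sig.parameters:
                kwargs[key] = value
    return fn(**kwargs)                          # declared defaults fill anything missing

def fn(x, y=100):
    return x + y

assert sketch_auto_param_call(fn, {'x': 10, 'z': 20}) == 110
assert sketch_auto_param_call(fn, {'x1': 10, 'y1': 20}, mapping={'x1': 'x', 'y1': 'y'}) == 30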
class TestCheckNumberOfParameters: | |||
def test_validate_every(self): | |||
def validate_every(trainer): | |||
pass | |||
_check_valid_parameters_number(validate_every, expected_params=['trainer']) | |||
# an extra parameter without a default raises
def validate_every(trainer, other): | |||
pass | |||
with pytest.raises(RuntimeError) as exc_info: | |||
_check_valid_parameters_number(validate_every, expected_params=['trainer']) | |||
assert "2 parameters" in exc_info.value.args[0] | |||
print(exc_info.value.args[0]) | |||
# an extra parameter with a default is fine
def validate_every(trainer, other=1): | |||
pass | |||
_check_valid_parameters_number(validate_every, expected_params=['trainer']) | |||
# more expected parameters than the function accepts
def validate_every(trainer): | |||
pass | |||
with pytest.raises(RuntimeError) as exc_info: | |||
_check_valid_parameters_number(validate_every, expected_params=['trainer', 'other']) | |||
assert "accepts 1 parameters" in exc_info.value.args[0] | |||
print(exc_info.value.args[0]) | |||
# works with functools.partial
def validate_every(trainer, other): | |||
pass | |||
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer']) | |||
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other']) | |||
with pytest.raises(RuntimeError) as exc_info: | |||
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other', 'more']) | |||
assert 'accepts 2 parameters' in exc_info.value.args[0] | |||
print(exc_info.value.args[0]) | |||
# with *args or **kwargs, extra expected parameters do not raise
def validate_every(trainer, *args): | |||
pass | |||
_check_valid_parameters_number(validate_every, expected_params=['trainer', 'other', 'more']) | |||
def validate_every(trainer, **kwargs): | |||
pass | |||
_check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more']) | |||
# for methods of a class, self is excluded from the count
class InnerClass: | |||
def demo(self, x): | |||
pass | |||
def no_param(self): | |||
pass | |||
def param_kwargs(self, **kwargs): | |||
pass | |||
inner = InnerClass() | |||
with pytest.raises(RuntimeError) as exc_info: | |||
_check_valid_parameters_number(inner.demo, expected_params=['trainer', 'other', 'more']) | |||
assert 'accepts 1 parameters' in exc_info.value.args[0] | |||
_check_valid_parameters_number(inner.demo, expected_params=['trainer']) | |||
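Taken together, these cases pin down a simple contract: extra parameters with defaults are fine, *args/**kwargs absorb anything, and arguments bound by functools.partial or by self do not count. A rough sketch under those assumptions (inspect.signature already drops self and partial-bound arguments); illustrative only, not the library code:

import inspect

def sketch_check_parameters_number(fn, expected_params):
    sig = inspect.signature(fn)
    params = list(sig.parameters.values())
    if any(p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD) for p in params):
        return  # *args / **kwargs absorb any number of call-time arguments
    required = [p for p in params if p.default is p.empty]
    if len(required) > len(expected_params):  # a required parameter would go unfilled
        raise RuntimeError(f"{fn} expects at least {len(required)} parameters")
    if len(params) < len(expected_params):    # fn cannot accept everything passed in
        raise RuntimeError(f"{fn} accepts {len(params)} parameters")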
def test_get_fun_msg(): | |||
def demo(x): | |||
pass | |||
print(_get_fun_msg(_get_fun_msg)) |
@@ -101,12 +101,18 @@ class RecordTrainerEventTriggerCallback(Callback): | |||
def on_after_backward(self, trainer): | |||
print("on_after_backward") | |||
def on_before_optimizer_step(self, trainer, optimizers): | |||
print("on_before_optimizer_step") | |||
def on_before_optimizers_step(self, trainer, optimizers): | |||
print("on_before_optimizers_step") | |||
def on_after_optimizers_step(self, trainer, optimizers): | |||
print("on_after_optimizers_step") | |||
def on_before_zero_grad(self, trainer, optimizers): | |||
print("on_before_zero_grad") | |||
def on_after_zero_grad(self, trainer, optimizers): | |||
print("on_after_zero_grad") | |||
def on_validate_begin(self, trainer): | |||
print("on_validate_begin") | |||
@@ -37,6 +37,7 @@ class TorchNormalModel_Classification_1(nn.Module): | |||
x = torch.max(x, dim=-1)[1] | |||
return {"preds": x, "target": y} | |||
class TorchNormalModel_Classification_2(nn.Module): | |||
""" | |||
Implements only a forward function, to test the scenario where the user initializes DDP outside the Trainer;
@@ -61,5 +62,31 @@ class TorchNormalModel_Classification_2(nn.Module): | |||
return {"loss": loss, "preds": x, "target": y} | |||
class TorchNormalModel_Classification_3(nn.Module): | |||
""" | |||
Implements only a forward function, to test the scenario where the user initializes DDP outside the Trainer;
auto_param_call is disabled, so forward takes a single batch argument;
""" | |||
def __init__(self, num_labels, feature_dimension): | |||
super(TorchNormalModel_Classification_3, self).__init__() | |||
self.num_labels = num_labels | |||
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10) | |||
self.ac1 = nn.ReLU() | |||
self.linear2 = nn.Linear(in_features=10, out_features=10) | |||
self.ac2 = nn.ReLU() | |||
self.output = nn.Linear(in_features=10, out_features=num_labels) | |||
self.loss_fn = nn.CrossEntropyLoss() | |||
def forward(self, batch): | |||
x = batch["x"] | |||
y = batch["y"] | |||
x = self.ac1(self.linear1(x)) | |||
x = self.ac2(self.linear2(x)) | |||
x = self.output(x) | |||
loss = self.loss_fn(x, y) | |||
x = torch.max(x, dim=-1)[1] | |||
return {"loss": loss, "preds": x, "target": y} | |||
@@ -2,34 +2,31 @@ import os | |||
import sys | |||
import __main__ | |||
from functools import wraps | |||
import inspect | |||
from inspect import ismethod | |||
import functools | |||
from copy import deepcopy | |||
from io import StringIO | |||
import time | |||
import numpy as np | |||
from fastNLP.core.utils.utils import get_class_that_defined_method | |||
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK | |||
from fastNLP.core.drivers.utils import distributed_open_proc | |||
from fastNLP.core.log import logger | |||
def get_class_that_defined_method(meth): | |||
if isinstance(meth, functools.partial): | |||
return get_class_that_defined_method(meth.func) | |||
if inspect.ismethod(meth) or (inspect.isbuiltin(meth) and getattr(meth, '__self__', None) is not None and getattr(meth.__self__, '__class__', None)): | |||
for cls in inspect.getmro(meth.__self__.__class__): | |||
if meth.__name__ in cls.__dict__: | |||
return cls | |||
meth = getattr(meth, '__func__', meth) # fallback to __qualname__ parsing | |||
if inspect.isfunction(meth): | |||
cls = getattr(inspect.getmodule(meth), | |||
meth.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0], | |||
None) | |||
if isinstance(cls, type): | |||
return cls | |||
return getattr(meth, '__objclass__', None) # handle special descriptor objects | |||
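Based only on the removed code above, a few illustrative calls (the Demo class is hypothetical and must be defined at module level for the __qualname__ fallback to resolve):

import functools

class Demo:
    def meth(self):
        pass

assert get_class_that_defined_method(Demo().meth) is Demo                     # bound method: MRO walk
assert get_class_that_defined_method(Demo.meth) is Demo                       # plain function: __qualname__ parsing
assert get_class_that_defined_method(functools.partial(Demo().meth)) is Demo  # partial is unwrapped first
assert get_class_that_defined_method(lambda x: x) is None                     # no defining class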
def recover_logger(fn): | |||
@wraps(fn) | |||
def wrapper(*args, **kwargs): | |||
# save the logger's state (handlers and level)
handlers = [handler for handler in logger.handlers] | |||
level = logger.level | |||
res = fn(*args, **kwargs) | |||
logger.handlers = handlers | |||
logger.setLevel(level) | |||
return res | |||
return wrapper | |||
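A hedged usage example (the test body is hypothetical): whatever the decorated function does to the logger's level or handler list is undone once it returns.

@recover_logger
def noisy_test():
    logger.setLevel('DEBUG')  # temporary: the previous level is restored afterwards
    logger.handlers.clear()   # temporary: the saved handler list is re-attached

noisy_test()
# here logger.level and logger.handlers hold their pre-call values again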
def magic_argv_env_context(fn): | |||