@@ -10,6 +10,7 @@ from .utils import _get_monitor_value
 from fastNLP.core.callbacks.callback_events import _SingleEventState
 from fastNLP.core.log import logger
 from fastNLP.core.utils import apply_to_collection
+from fastNLP.core.utils.utils import _check_valid_parameters_number


 class Callback:
@@ -32,100 +33,225 @@ class Callback:
     def on_sanity_check_end(self, trainer, sanity_check_res):
         r"""
         Triggered after the sanity check ('dry run') ends;

+        :param trainer:
+        :param sanity_check_res: the evaluate results of the sanity check
+        :return:
         """
         pass

     def on_train_begin(self, trainer):
         r"""
         Triggered before training starts;

+        :param trainer:
+        :return:
         """
         pass

     def on_train_end(self, trainer):
         r"""
         Triggered after training finishes;

+        :param trainer:
+        :return:
         """
         pass

     def on_train_epoch_begin(self, trainer):
         r"""
         Triggered before each training epoch starts;

+        :param trainer:
+        :return:
         """
         pass

     def on_train_epoch_end(self, trainer):
         r"""
-        Triggered after each training epoch finishes;
+        Triggered after each training epoch finishes; at this point trainer.cur_epoch_idx has already been incremented by 1.

+        :param trainer:
+        :return:
         """
         pass

     def on_fetch_data_begin(self, trainer):
         r"""
-        Triggered before a concrete batch is fetched during training;
+        Triggered when the next batch of data is about to be fetched during training

+        :param trainer:
+        :return:
         """
         pass

     def on_fetch_data_end(self, trainer):
         r"""
-        Triggered after a concrete batch has been fetched during training;
+        Triggered after the current batch of data has been fetched during training;

+        :param trainer:
+        :return:
         """
         pass
-    def on_train_batch_begin(self, trainer, batch, indices=None):
+    def on_train_batch_begin(self, trainer, batch, indices):
         r"""
-        Triggered before a concrete batch starts during training;
+        Triggered after the data has been fetched, input_mapping (if the Trainer was given that parameter) has been
+        applied, and the tensors in the batch have been moved to the specified device.
+        The data in batch is either in the per-batch format returned by the Dataloader, or the result of input_mapping.
+        If batch is of dict type, directly adding or removing keys, or modifying values, affects the batch data that
+        is fed into the model.

         :param trainer: `fastNLP.Trainer`
-        :param batch: the batch currently being run;
-        :param indices: the position of the current batch within the epoch, so that this callback can be used to
-            conveniently locate specific data;
+        :param batch: the batch data, after input_mapping (if any) has been applied and after being moved to the
+            specified device.
+        :param list[int] indices: which samples of the dataset the current batch corresponds to
         """
         pass
     def on_train_batch_end(self, trainer):
+        """
+        Called after one batch has finished its training step (forward), gradient backpropagation (backward),
+        parameter update (step) and gradient zeroing, and after batch_idx_in_epoch and global_forward_batches have
+        each been incremented by 1. The parameter update and gradient zeroing take accumulation_steps into account,
+        so they do not necessarily run on the current batch.
+
+        :param trainer:
+        :return:
+        """
         pass

     def on_exception(self, trainer, exception):
+        """
+        Called when an exception is encountered during training.
+
+        :param trainer:
+        :param exception: the exception that was encountered.
+        :return:
+        """
         pass

     def on_save_model(self, trainer):
+        """
+        Called when the model is about to be saved; at this point the model has not been saved yet.
+
+        :param trainer:
+        :return:
+        """
         pass

     def on_load_model(self, trainer):
+        """
+        Called when the model is about to be loaded; at this point the model has not been loaded yet.
+
+        :param trainer:
+        :return:
+        """
         pass
     def on_save_checkpoint(self, trainer) -> Dict:
         """
-        When two callbacks are determined to be the same (same callback_name, i.e. they serve the same role), they
-        should save in this function whatever state the callback needs to keep working; this function should not be
-        used to decide whether two callbacks are the same;
+        Triggered when the Trainer is about to save a checkpoint; this function saves whatever data the current
+        callback needs in order to be restored.
+
+        :param trainer:
+        :return:
         """
         pass

     def on_load_checkpoint(self, trainer, states: Optional[Dict]):
         r"""
-        If a callback saved no state before the resumption from a checkpoint, or its `callback_name` collides with
-        that of another callback, `states` is None;
+        Triggered when the Trainer is about to restore from a checkpoint (the Trainer and Driver have already loaded
+        their own states); the states parameter is the return value of on_save_checkpoint().
+
+        :param trainer:
+        :param states:
+        :return:
         """
         pass
     def on_before_backward(self, trainer, outputs):
+        """
+        Executed before backward.
+
+        :param trainer:
+        :param outputs: what the model returned. If output_mapping is set, outputs holds the result after
+            output_mapping has been applied.
+        :return:
+        """
         pass

     def on_after_backward(self, trainer):
+        """
+        Executed after backward. In multi-GPU settings, because of accumulation_steps, gradient synchronization is
+        only triggered on the backward pass that actually updates the parameters; therefore, with multiple GPUs and
+        accumulation_steps, gradients may be inconsistent across GPUs at some steps.
+
+        :param trainer:
+        :return:
+        """
         pass

-    def on_before_optimizer_step(self, trainer, optimizers):
+    def on_before_optimizers_step(self, trainer, optimizers):
+        """
+        Called before the optimizers perform their update. This hook is not necessarily triggered on every forward
+        pass; the actual calls depend on accumulation_steps.
+
+        :param trainer:
+        :param optimizers: the optimizers, as passed to the Trainer at initialization.
+        :return:
+        """
+        pass
+
+    def on_after_optimizers_step(self, trainer, optimizers):
+        """
+        Called after the optimizers perform their update. This hook is not necessarily triggered on every forward
+        pass; the actual calls depend on accumulation_steps.
+
+        :param trainer:
+        :param optimizers: the optimizers, as passed to the Trainer at initialization.
+        :return:
+        """
         pass
     def on_before_zero_grad(self, trainer, optimizers):
+        """
+        Called before the model's gradients are zeroed. This hook is not necessarily triggered on every forward pass;
+        the actual calls depend on accumulation_steps.
+
+        :param trainer:
+        :param optimizers: the optimizers, as passed to the Trainer at initialization.
+        :return:
+        """
+        pass
+
+    def on_after_zero_grad(self, trainer, optimizers):
+        """
+        Called after the model's gradients are zeroed. This hook is not necessarily triggered on every forward pass;
+        the actual calls depend on accumulation_steps.
+
+        :param trainer:
+        :param optimizers: the optimizers, as passed to the Trainer at initialization.
+        :return:
+        """
         pass

     def on_validate_begin(self, trainer):
+        """
+        Called when a validation run is about to start. If the validation frequency is controlled by step count or by
+        a custom function, this hook is called after on_train_batch_end; if it is controlled by epoch count, it is
+        called after on_train_epoch_end.
+
+        :param trainer:
+        :return:
+        """
         pass

     def on_validate_end(self, trainer, results):
+        """
+        Called when validation ends; the validation results are passed in.
+
+        :param trainer:
+        :param results:
+        :return:
+        """
         pass

     @property
     def callback_name(self):
+        """
+        The name of the callback. We use this name to read the matching state from a checkpoint and pass it to the
+        on_load_checkpoint() function.
+
+        :return:
+        """
         return self.__class__.__name__
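Taken together, these hooks make a custom callback a matter of overriding only what is needed. A minimal sketch against the signatures above (the printed fields are just illustrations):

    from fastNLP.core.callbacks import Callback

    class PrintingCallback(Callback):
        """A minimal custom callback; every hook not overridden stays a no-op."""

        def on_train_epoch_end(self, trainer):
            # cur_epoch_idx has already been incremented when this hook fires
            print(f"finished epoch {trainer.cur_epoch_idx}")

        def on_validate_end(self, trainer, results):
            # results is the evaluation result dict passed in by the Trainer
            print(f"validation at batch {trainer.global_forward_batches}: {results}")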
@@ -174,7 +300,11 @@ class HasMonitorCallback(Callback):
         self.must_have_moinitor = must_have_monitor

     def set_monitor(self, monitor, larger_better):
-        self.monitor = str(monitor) if monitor is not None else None
+        if callable(monitor):  # check that it accepts a single argument
+            _check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor')
+            self.monitor = monitor
+        else:
+            self.monitor = str(monitor) if monitor is not None else None
         self.larger_better = bool(larger_better)
         if larger_better:
             self.monitor_value = float('-inf')
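A monitor can now be a function over the whole evaluation-result dict instead of a single key; a sketch, with hypothetical result keys:

    def combined_monitor(results):
        # any float derived from the results dict can serve as the monitor
        return results.get("f1#dev", 0.0) - 0.1 * results.get("loss#dev", 0.0)

    # any HasMonitorCallback subclass accepts this in place of a string key
    some_callback.set_monitor(combined_monitor, larger_better=True)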
@@ -197,24 +327,33 @@ class HasMonitorCallback(Callback):
             raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
                                f"You can set it in the initialization or through Trainer.")

-    def get_monitor_value(self, results:Dict)->float:
+    def get_monitor_value(self, results:Dict)->Union[float, None]:
         """
         Get the value of the monitor. If the monitor is not found directly, fuzzy matching is attempted, and the
         matched key is stored on the self._real_monitor attribute.

         :param results:
-        :return:
+        :return: if None, no suitable monitor was found this time
         """
         if len(results)==0:
-            return 0
+            return None
         # make sure every tensor has been converted to a plain python type
         results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
         use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
                                                         real_monitor=self._real_monitor,
                                                         res=results)
-        if self._real_monitor != use_monitor:  # a substitution happened, so log it
-            logger.warning(
-                f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
-                f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.")
+        if monitor_value is None:
+            return monitor_value
+        # first run
+        if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
+            logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
+                           f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.")
+        # the matched monitor differs from the one used last time
+        elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
+            logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
+                           f"The expected monitor is:`{self.monitor}`, last used monitor is:"
+                           f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
+                           f"customized monitor function when the evaluation results are varying between validation.")
         self._real_monitor = use_monitor
         return monitor_value
@@ -222,14 +361,33 @@ class HasMonitorCallback(Callback):
         """
         Check whether monitor_value is better.

-        :param monitor_value:
+        :param monitor_value: the monitor_value to check. If None, False is returned
         :param keep_if_better: if the passed-in monitor_value is better, keep it.
         :return:
         """
+        if monitor_value is None:
+            return False
+        better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
+        if keep_if_better and better:
+            self.monitor_value = monitor_value
+        return better

+    def is_former_monitor_value_better(self, monitor_value1, monitor_value2):
+        """
+        Of the two passed-in values, whether monitor_value1 is the better one.
+
+        :param monitor_value1:
+        :param monitor_value2:
+        :return:
+        """
+        if monitor_value1 is None and monitor_value2 is None:
+            return True
+        if monitor_value1 is None:
+            return False
+        if monitor_value2 is None:
+            return True
         better = False
-        if (self.larger_better and monitor_value > self.monitor_value) or \
-            (not self.larger_better and monitor_value < self.monitor_value):
+        if (self.larger_better and monitor_value1 > monitor_value2) or \
+            (not self.larger_better and monitor_value1 < monitor_value2):
             better = True
-            if keep_if_better:
-                self.monitor_value = monitor_value
         return better
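The None handling above makes any concrete value beat None. With a HasMonitorCallback instance `cb` configured with larger_better=True, the expected outcomes are:

    cb.is_former_monitor_value_better(0.9, 0.8)   # True: 0.9 > 0.8
    cb.is_former_monitor_value_better(0.8, None)  # True: a value beats None
    cb.is_former_monitor_value_better(None, 0.8)  # False
    cb.is_better_monitor_value(None)              # False: None is never an improvement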
@@ -74,28 +74,30 @@ class EventEnum(_SingleEventState, Enum):
 @unique
 class Events(EventEnum):
-    ON_AFTER_TRAINER_INITIALIZED = "on_after_trainer_initialized"
-    ON_SANITY_CHECK_BEGIN = "on_sanity_check_begin"
-    ON_SANITY_CHECK_END = "on_sanity_check_end"
-    ON_TRAIN_BEGIN = "on_train_begin"
-    ON_TRAIN_END = "on_train_end"
-    ON_TRAIN_EPOCH_BEGIN = "on_train_epoch_begin"
-    ON_TRAIN_EPOCH_END = "on_train_epoch_end"
-    ON_FETCH_DATA_BEGIN = "on_fetch_data_begin"
-    ON_FETCH_DATA_END = "on_fetch_data_end"
-    ON_TRAIN_BATCH_BEGIN = "on_train_batch_begin"
-    ON_TRAIN_BATCH_END = "on_train_batch_end"
-    ON_EXCEPTION = "on_exception"
-    ON_SAVE_MODEL = "on_save_model"
-    ON_LOAD_MODEL = "on_load_model"
-    ON_SAVE_CHECKPOINT = "on_save_checkpoint"
-    ON_LOAD_CHECKPOINT = "on_load_checkpoint"
-    ON_BEFORE_BACKWARD = "on_before_backward"
-    ON_AFTER_BACKWARD = "on_after_backward"
-    ON_BEFORE_OPTIMIZER_STEP = "on_before_optimizer_step"
-    ON_BEFORE_ZERO_GRAD = "on_before_zero_grad"
-    ON_VALIDATE_BEGIN = "on_validate_begin"
-    ON_VALIDATE_END = "on_validate_end"
+    on_after_trainer_initialized = "on_after_trainer_initialized"
+    on_sanity_check_begin = "on_sanity_check_begin"
+    on_sanity_check_end = "on_sanity_check_end"
+    on_train_begin = "on_train_begin"
+    on_train_end = "on_train_end"
+    on_train_epoch_begin = "on_train_epoch_begin"
+    on_train_epoch_end = "on_train_epoch_end"
+    on_fetch_data_begin = "on_fetch_data_begin"
+    on_fetch_data_end = "on_fetch_data_end"
+    on_train_batch_begin = "on_train_batch_begin"
+    on_train_batch_end = "on_train_batch_end"
+    on_exception = "on_exception"
+    on_save_model = "on_save_model"
+    on_load_model = "on_load_model"
+    on_save_checkpoint = "on_save_checkpoint"
+    on_load_checkpoint = "on_load_checkpoint"
+    on_before_backward = "on_before_backward"
+    on_after_backward = "on_after_backward"
+    on_before_optimizers_step = "on_before_optimizers_step"
+    on_after_optimizers_step = "on_after_optimizers_step"
+    on_before_zero_grad = "on_before_zero_grad"
+    on_after_zero_grad = "on_after_zero_grad"
+    on_validate_begin = "on_validate_begin"
+    on_validate_end = "on_validate_end"


 class EventsList:
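Because the member names now equal the callback method names, registering a function for an event reads uniformly; a sketch using the `Trainer.on` decorator described later in this changeset (Trainer assumed to be imported from fastNLP):

    from fastNLP.core.callbacks import Events

    @Trainer.on(Events.on_train_begin)
    def greet(trainer):
        # the function takes the same parameters as the Callback hook, minus self
        print("training starts")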
@@ -169,20 +171,8 @@ class Filter:
             self.num_called += 1

             # our callback functions have fixed inputs, and we can guarantee that the first parameter is always the trainer;
-            # we can therefore pull the trainer out of the callback function's inputs and hand it to our trainer,
-            # enabling some more complex logic;
-            # meanwhile, when the first parameter of the function decorated by Filter is not a trainer, we pass only
-            # self to the _filter function;

-            # parameter-extraction logic;
-            trainer = kwargs.get("trainer", None)
-            if trainer is None and len(args) > 0:
-                trainer = args[0]
-            if isinstance(trainer, fastNLP.Trainer):  # because of the circular import problem we cannot rely on fastNLP.Trainer
-                # directly here, since the Trainer also imports this module, while the Controller does not;
-                param = (self, trainer)
-            else:
-                param = (self, )
-            if self._filter(*param):
+            trainer = args[0]
+            if self._filter(self, trainer):
                 self.num_executed += 1
                 return fn(*args, **kwargs)
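With the simplified extraction, a Filter only assumes that the wrapped function receives the trainer first. A usage sketch based on the `every` counter (the value 2 is illustrative):

    every_second = Filter(every=2)   # pass the call through every second time

    def on_train_batch_end(trainer):
        ...                          # runs only when the filter lets the call through

    gated = every_second(on_train_batch_end)
    # each gated(trainer) call increments num_called; num_executed counts real runs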
@@ -278,13 +278,21 @@ class CallbackManager:
         pass

     @_transfer
-    def on_before_optimizer_step(self, trainer, optimizers):
+    def on_before_optimizers_step(self, trainer, optimizers):
+        pass
+
+    @_transfer
+    def on_after_optimizers_step(self, trainer, optimizers):
         pass

     @_transfer
     def on_before_zero_grad(self, trainer, optimizers):
         pass

+    @_transfer
+    def on_after_zero_grad(self, trainer, optimizers):
+        pass
+
     @_transfer
     def on_validate_begin(self, trainer):
         pass
@@ -10,12 +10,10 @@ from copy import deepcopy
 import fastNLP
-from .callback import Callback, HasMonitorCallback
-from fastNLP.core.callbacks.utils import _get_monitor_value
+from .callback import HasMonitorCallback
 from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_LAUNCH_TIME
 from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir
-from fastNLP.core.utils import apply_to_collection


 class CheckpointCallback(HasMonitorCallback):
@@ -167,6 +165,8 @@ class CheckpointCallback(HasMonitorCallback):
         """
         if self.save_topk is not None:
             monitor_value = self.get_monitor_value(results=results)
+            if monitor_value is None:
+                return
             folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \
                           f"-{self._real_monitor}_{monitor_value}"
@@ -178,8 +178,7 @@ class CheckpointCallback(HasMonitorCallback):
             else:
                 _least_valuable_model = (min if self.larger_better else max)(self._topk_model,
                                                                              key=lambda x: self._topk_model[x])
-                if (self.larger_better and monitor_value > self._topk_model[_least_valuable_model]) or \
-                        (self.larger_better is False and monitor_value < self._topk_model[_least_valuable_model]):
+                if self.is_former_monitor_value_better(monitor_value, self._topk_model[_least_valuable_model]):
                     self._topk_model[folder_name] = monitor_value
                     _should_save = True
                     self._topk_model.pop(_least_valuable_model)
@@ -208,21 +207,6 @@ class CheckpointCallback(HasMonitorCallback):
             **self.kwargs
         )

-    def _get_validate_metric(self, res: Dict):
-        """
-        This function finds, in the `Evaluator` results, the metric result that belongs to the current
-        CheckpointCallback (according to monitor); if the user's input is not found in res, we look through all keys
-        of the validate result dict and, matching by longest common substring, use the value of the longest match;
-
-        :param res:
-        :return:
-        """
-        use_monitor, value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=res)
-        if self._real_monitor != use_monitor:
-            logger.warning(f"We can not find `{self._real_monitor}` in the evaluation result (with keys as {list(res.keys())}), "
-                           f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.")
-        self._real_monitor = use_monitor
-        return value

     @property
     def folder_prefix(self):
         raise NotImplementedError("The `folder_prefix` is not specified")
@@ -248,7 +232,8 @@ class ModelCheckpointCallback(CheckpointCallback):
        If model_save_fn is not None, fastNLP passes the absolute path of folder to that function and creates no files under that folder itself.
     :param monitor: the name of the metric to monitor. If no exactly matching name is found in the evaluation results, the
-        longest-common-substring match is used as the monitor. If None, the value is taken from the Trainer.
+        longest-common-substring match is used as the monitor. If None, the value is taken from the Trainer. A function
+        may also be passed; it receives the evaluation results (a dict) and returns a float used as the monitor value.
     :param save_folder: the folder to save into; fastNLP creates a timestamped subfolder under it and saves there, so
        different runs can be saved into different timestamped folders. If None, the current folder is used.
     :param save_every_n_epochs: save once every this many epochs.
@@ -295,7 +280,8 @@ class TrainerCheckpointCallback(CheckpointCallback):
        If model_save_fn is not None, fastNLP only generates the fastnlp_trainer.pkl.tar file under each folder.
     :param monitor: the name of the metric to monitor. If no exactly matching name is found in the evaluation results, the
-        longest-common-substring match is used as the monitor. If None, the value is taken from the Trainer.
+        longest-common-substring match is used as the monitor. If None, the value is taken from the Trainer. A function
+        may also be passed; it receives the evaluation results (a dict) and returns a float used as the monitor value.
     :param save_folder: the folder to save into; fastNLP creates a timestamped subfolder under it and saves there, so
        different runs can be saved into different timestamped folders. If None, the current folder is used.
     :param save_every_n_epochs: save once every this many epochs.
@@ -2,17 +2,18 @@ __all__ = [
     'EarlyStopCallback'
 ]

-from typing import Dict
+from typing import Dict, Union, Callable

 from .callback import HasMonitorCallback
 from fastNLP.core.utils.exceptions import EarlyStopException


 class EarlyStopCallback(HasMonitorCallback):
-    def __init__(self, monitor:str=None, larger_better:bool=True, patience:int=10):
+    def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10):
         """
-        :param str monitor: the metric to monitor. If None, the monitor set by the Trainer is used.
+        :param str monitor: the metric to monitor. If None, the monitor set by the Trainer is used. A function may also
+            be passed; it receives the evaluation results (a dict) and returns a float used as the monitor value.
         :param larger_better: whether a larger monitor value is better.
         :param patience: stop after this many validations without improvement.
         """
@@ -21,9 +22,9 @@ class EarlyStopCallback(HasMonitorCallback):
         self.patience = patience

     def on_validate_end(self, trainer, results):
-        if len(results)==0:
-            return
         monitor_value = self.get_monitor_value(results)
+        if monitor_value is None:
+            return
         if self.is_better_monitor_value(monitor_value, keep_if_better=True):
             self.wait = 0
         else:
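Combined with the callable-monitor support, early stopping can track a derived value; a sketch (import path and result key assumed):

    from fastNLP.core.callbacks import EarlyStopCallback

    early_stop = EarlyStopCallback(
        monitor=lambda results: results.get("f1#dev", 0.0),  # hypothetical key
        larger_better=True,
        patience=5,   # give up after 5 validations without improvement
    )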
@@ -3,7 +3,7 @@ __all__ = [
 ]

 import os
-from typing import Optional, Callable
+from typing import Optional, Callable, Union
 from .callback import HasMonitorCallback
 from io import BytesIO
 import shutil
@@ -14,14 +14,15 @@ from fastNLP.envs import all_rank_call

 class LoadBestModelCallback(HasMonitorCallback):
-    def __init__(self, monitor:str=None, larger_better:bool = True, only_state_dict:bool = True,
+    def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool = True, only_state_dict:bool = True,
                  save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None,
                  model_load_fn:Optional[Callable] = None,
                  delete_after_train:bool = True):
         """
         Saves the model with the best monitor value and reloads it when training ends. The best model can only be
         loaded if training finishes normally.

-        :param str monitor: the metric to monitor. If None, the monitor set by the Trainer is used.
+        :param str monitor: the metric to monitor. If None, the monitor set by the Trainer is used. A function may also
+            be passed; it receives the evaluation results (a dict) and returns a float used as the monitor value.
         :param larger_better: whether the metric value is better when larger.
         :param save_folder: the folder to save to; if empty, the weights are kept in memory. If not empty, a copy of the
            weights is saved to file; for multi-machine training with a non-empty value, make sure every machine can
            access the path. This value must not be empty when model_save_fn is not None.
@@ -78,9 +79,9 @@ class LoadBestModelCallback(HasMonitorCallback):
         self.get_monitor_value(sanity_check_res)

     def on_validate_end(self, trainer, results):
-        if len(results)==0:
-            return
         monitor_value = self.get_monitor_value(results)
+        if monitor_value is None:
+            return
         if self.is_better_monitor_value(monitor_value, keep_if_better=True):
             if self.real_save_folder:
                 trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
@@ -45,6 +45,7 @@ class RichCallback(ProgressCallback):
         :param print_every: update the display every this many batches.
         :param loss_round_ndigit: how many significant digits to keep when displaying the loss
         :param monitor: when the result under this key improves, it is shown in a different color as a hint. If None, the monitor set in the trainer is used.
+            A function may also be passed; it receives the evaluation results (a dict) and returns a float used as the monitor value.
         :param larger_better: whether a larger monitor result is better.
         :param format_json: whether to format the json before printing
         """
@@ -135,7 +136,8 @@ class RawTextCallback(ProgressCallback):
         :param print_every: update the display every this many batches.
         :param loss_round_ndigit: how many significant digits to keep when displaying the loss
-        :param monitor: when the result under this key improves, it is shown in a different color as a hint.
+        :param monitor: when the result under this key improves, it is shown in a different color as a hint. A function
+            may also be passed; it receives the evaluation results (a dict) and returns a float used as the monitor value.
         :param larger_better: whether a larger monitor result is better.
         :param format_json: whether to format the json before printing
         """
@@ -1,9 +1,10 @@
-from typing import Optional
+from typing import Optional, Union
 from fastNLP.core.log.logger import logger
 from difflib import SequenceMatcher
+from fastNLP.core.utils.utils import _get_fun_msg


-def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->(str, float):
+def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str], res: dict) ->(str, float):
     """
     Look up monitor in res and return it. If monitor is not found, try _real_monitor; if _real_monitor is None, try
     fuzzy-matching against the value of monitor.
@@ -11,10 +12,19 @@ def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->(
     :param monitor:
     :param real_monitor:
     :param res:
-    :return: two values (str, value), where str is the key that was finally found and value is the value for that key
+    :return: two values (str, value), where str is the key that was finally found and value is the value for that key.
+        If value is None, no matching monitor was found in the current results
     """
     if len(res)==0:
-        return monitor, 0
+        return monitor, None

+    if callable(monitor):
+        try:
+            monitor_value = monitor(res)
+        except BaseException as e:
+            logger.error(f"Exception happens when calling customized monitor function:{_get_fun_msg(monitor)}.")
+            raise e
+        return monitor, monitor_value

     if monitor in res:
         return monitor, res[monitor]
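For string monitors that are not found, the matching further down in this function relies on difflib's SequenceMatcher. A standalone sketch of the idea, not the exact fastNLP implementation:

    from difflib import SequenceMatcher

    def best_matching_key(monitor, res):
        # choose the result key sharing the longest common substring with `monitor`
        def match_size(key):
            sm = SequenceMatcher(None, monitor, str(key))
            return sm.find_longest_match(0, len(monitor), 0, len(str(key))).size
        return max(res, key=match_size)

    # best_matching_key("acc", {"acc#dev": 0.9, "loss#dev": 0.3}) -> "acc#dev"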
@@ -5,7 +5,7 @@ __all__ = [

 from abc import ABCMeta, abstractmethod
-from typing import Any, Dict, List, Callable, Union
+from typing import Any, Dict, List, Callable, Union, Tuple
 from numbers import Number
 import warnings
@@ -35,7 +35,7 @@ class SetInputOrTargetException(Exception):
         self.field_name = field_name  # the name of the current field


-def _get_ele_type_and_dim(cell: Any, dim=0):
+def _get_ele_type_and_dim(cell: Any, dim=0) -> Tuple[Any, int]:
     r"""
     Identify the type of cell and its number of dimensions

@@ -197,7 +197,7 @@ class _MultiCollator:
                 collator.set_input(*field_names)
                 flag = False
         if flag:
-            warnings.warn("AutoCollator is remove, set_input is unavailable!!")
+            warnings.warn("AutoCollator is removed, set_input is unavailable!!")
         return self
@@ -206,7 +206,7 @@ class AutoCollator(Collator):
     def __init__(self, as_numpy: bool):
         super(AutoCollator, self).__init__()
         self.pad_field_value = {}  # custom padding value per field, default 0
-        self.need_inputs = []  # the field names that are needed
+        self.need_inputs = set()  # the field names that are needed
         self.field_dtypes = None  # the dtype of each column's data cells
         self.field_dims = None  # the dimensionality of each column's data cells
         self.as_numpy = as_numpy
@@ -214,10 +214,17 @@ class AutoCollator(Collator):
     def __call__(self, ins_lst: List[Dict]) -> dict:
         if len(self.need_inputs) == 0:
             raise ValueError({"set_inputs is None, you should use set_inputs method first!!"})
+        # TODO should first check which fields need padding, then check whether those fields can actually be padded
         # case 1: values were set via set_input
        # case 2: decide whether to pad based on the data type
         if self.field_dtypes is None and self.field_dims is None:
-            self.field_dtypes, self.field_dims = _get_ds_type_dim(ins_lst[0])
+            field_dtypes, field_dims = {}, {}
+            for key, value in ins_lst[0].items():
+                if key in self.need_inputs and self.pad_field_value.get(key, 0) is not None:
+                    field_dtypes[key], field_dims[key] = _get_ele_type_and_dim(value)
+            self.field_dtypes = field_dtypes
+            self.field_dims = field_dims

         pack_ins_lst, pad_ins_lst = {field_name: []
                                      for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {}
@@ -233,13 +240,13 @@ class AutoCollator(Collator):
         if len(self.pad_field_value.keys()) > 0:
             # drop the columns that do not need padding; set_input columns that do not exist are ignored
-            drop_field_names = []
+            non_pad_field_names = []
             for k, v in self.pad_field_value.items():
                 if v is None:
-                    drop_field_names.append(k)
+                    non_pad_field_names.append(k)
             # drop_field_names = list(set(list(ins_lst[0].keys())) - set(drop_fields))
-            for field_name in drop_field_names:
+            for field_name in non_pad_field_names:
                 field_array = pack_ins_lst.pop(field_name)
                 pad_ins_lst[field_name] = np.array(field_array)
@@ -269,7 +276,7 @@ class AutoCollator(Collator):
     def set_input(self, *field_names):
         for field_name in field_names:
-            self.need_inputs.append(field_name)
+            self.need_inputs.add(field_name)


 def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool):
@@ -11,11 +11,12 @@ __all__ = [

 from fastNLP.core.drivers import Driver
 from fastNLP.core.drivers.utils import choose_driver
 from .loops import Loop, EvaluateBatchLoop
-from fastNLP.core.utils import check_fn_not_empty_params, auto_param_call, dataclass_to_dict, \
+from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \
     match_and_substitute_params, f_rich_progress
 from fastNLP.core.metrics import Metric
 from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric
 from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader
+from fastNLP.core.utils.utils import _check_valid_parameters_number
 from fastNLP.core.log import logger
@@ -38,10 +39,11 @@ class Evaluator:
                  driver: Union[str, Driver] = 'single',
                  device: Optional[Union[int, List[int], str]] = None,
                  batch_step_fn: Optional[callable] = None,
-                 mode: str = "validate",
+                 mode: Optional[Union[str, callable]] = 'validate',  # first tries evaluate_step; if absent, forward; may also be a callable
                  input_mapping: Optional[Union[Callable, Dict]] = None,
                  output_mapping: Optional[Union[Callable, Dict]] = None,
-                 fp16: Optional[bool] = False,
+                 model_wo_auto_param_call: bool = False,
+                 fp16: bool = False,
                  verbose: int = 1,
                  **kwargs
                  ):
@@ -61,6 +63,9 @@ class Evaluator:
            if absent, the "validate_step" function is tried; if neither is found, the model's forward computation is used.
         :param input_mapping: what the dataloader outputs is processed by input_mapping before being fed to the model and the metrics
         :param output_mapping: what the model outputs is processed by output_mapping before being fed to the metrics.
+        :param model_wo_auto_param_call: whether to disable calling our auto_param_call to automatically match the
+            batch against the forward function's parameters; if False and the batch is a dict, we extract the objects
+            the forward function needs from the batch and pass them in; if True, we pass the batch straight through to
+            the forward function. Note the same logic applies to `train_step`, `validate_step` and `test_step`;
         :param fp16: whether to use fp16.
         :param verbose: whether to print the evaluate results.
         :param kwargs:
@@ -83,13 +88,13 @@ class Evaluator:
         self.model = model
         self.metrics = metrics

-        self.driver = choose_driver(model, driver, device, fp16=fp16, **kwargs)
+        self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, **kwargs)
         self.device = device
         self.verbose = verbose

-        assert check_fn_not_empty_params(batch_step_fn, 2), "Parameter `batch_step_fn` should be a callable object with " \
-                                                            "two parameters."
+        if batch_step_fn is not None:
+            _check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn')
         self.batch_step_fn = batch_step_fn

         self.mode = mode
@@ -131,6 +136,7 @@ class Evaluator:
         if self.progress_bar == 'auto':
             self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw'

+        self.driver.check_evaluator_mode(self.mode)
         self.driver.barrier()
@@ -150,8 +156,6 @@ class Evaluator:
         assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type."
         assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0."

-        self.driver.check_evaluator_mode(self.mode)
-
         if self.mode == 'validate':
             assert self.driver.has_validate_dataloaders()
         else:
@@ -219,7 +223,6 @@ class Evaluator:
     def remove_progress_bar(self, dataloader_name):
         if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'):
             f_rich_progress.destroy_task(self._rich_task_id)
-            f_rich_progress.refresh()  # so the final bar can disappear
             delattr(self, '_rich_task_id')
         elif self.progress_bar == 'raw':
             desc = 'Evaluation ends'
@@ -230,7 +233,6 @@ class Evaluator:
     def finally_progress_bar(self):
         if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'):
             f_rich_progress.destroy_task(self._rich_task_id)
-            f_rich_progress.refresh()
             delattr(self, '_rich_task_id')

     @property
@@ -355,20 +357,24 @@ class _MetricsWrapper:
         if is_dataclass(outputs):
             outputs = dataclass_to_dict(outputs)
         for metric in self._metrics:
+            args = []
             if not isinstance(batch, dict):
-                raise RuntimeError(f"When the output of the DataLoader is of type:`{type(batch)}`, please either directly"
-                                   f" return a dict from your DataLoader or use `input_mapping` to convert it into dict type.")
+                logger.warning_once(f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on "
+                                    f"the output of model to update metric.")
+            else:
+                args.append(batch)
             if not isinstance(outputs, dict):
-                raise RuntimeError(f"When the output of your model is of type:`{type(batch)}`, please either directly"
+                raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly"
                                    f" return a dict from your model or use `output_mapping` to convert it into dict type.")
             if isinstance(metric, Metric):
-                auto_param_call(metric.update, batch, outputs)
+                # this way the error message is clear when auto_param_call fails.
+                auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__)
             elif _is_torchmetrics_metric(metric):
-                auto_param_call(metric.update, batch, outputs)
+                auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__)
             elif _is_allennlp_metric(metric):
-                auto_param_call(metric.__call__, batch, outputs)
+                auto_param_call(metric.__call__, outputs, *args)
             elif _is_paddle_metric(metric):
-                res = auto_param_call(metric.compute, batch, outputs)
+                res = auto_param_call(metric.compute, outputs, *args)
                 metric.update(res)

     def reset(self):
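The new call order means a metric's update parameters are filled by name from the model outputs first and, when the batch is a dict, from the batch as well. A sketch of auto_param_call's matching, with hypothetical key names:

    def update(pred, target):
        return pred, target

    outputs = {"pred": [0.9, 0.1]}   # hypothetical model-output dict
    batch = {"target": [1, 0]}       # hypothetical batch dict

    # `pred` is taken from outputs, `target` from batch, both matched by name
    auto_param_call(update, outputs, batch)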
@@ -7,6 +7,7 @@ from typing import Optional, Callable

 from .loop import Loop
 from fastNLP.core.log import logger
 from fastNLP.core.utils import match_and_substitute_params
+from fastNLP.core.utils.exceptions import EarlyStopException


 class TrainBatchLoop(Loop):
@@ -23,13 +24,15 @@ class TrainBatchLoop(Loop):
             try:
                 trainer.on_fetch_data_begin()
                 batch = next(dataloader)
-                batch = match_and_substitute_params(trainer.input_mapping, batch)
                 indices = get_batch_indices()
-                batch = trainer.move_data_to_device(batch)
                 trainer.on_fetch_data_end()
+                batch = match_and_substitute_params(trainer.input_mapping, batch)
+                batch = trainer.move_data_to_device(batch)
             except StopIteration:
                 break
-            except BaseException as e:  # TODO record this information somewhere
+            except EarlyStopException:  # the Trainer handles the early-stop exception
+                break
+            except BaseException as e:
                 if indices:
                     logger.debug(f"The following exception happens when running on samples: {indices}")
                 raise e
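The EarlyStopException caught here is what EarlyStopCallback raises once patience runs out; any callback could end training the same way. A sketch, assuming the exception accepts a message:

    from fastNLP.core.utils.exceptions import EarlyStopException

    class StopAtTargetLoss(Callback):   # hypothetical callback
        def on_validate_end(self, trainer, results):
            if results.get("loss#dev", 1.0) < 0.01:   # hypothetical key and threshold
                raise EarlyStopException("target loss reached")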
@@ -14,6 +14,7 @@ __all__ = [

 from .loops import Loop, TrainBatchLoop
 from .utils import State, TrainerState
+from .utils.utils import check_validate_every
 from .evaluator import Evaluator
 from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader
 from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList, Filter
@@ -21,7 +22,8 @@ from fastNLP.core.callbacks.callback import _CallbackWrapper
 from fastNLP.core.callbacks.callback_events import _SingleEventState
 from fastNLP.core.drivers import Driver
 from fastNLP.core.drivers.utils import choose_driver
-from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext
+from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext
+from fastNLP.core.utils.utils import _check_valid_parameters_number
 from fastNLP.envs import rank_zero_call
 from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_MODEL_FILENAME
@@ -42,15 +44,16 @@ class Trainer(TrainerEventTrigger):
                  validate_dataloaders=None,
                  batch_step_fn: Optional[Callable] = None,
                  validate_batch_step_fn: Optional[Callable] = None,
-                 validate_mode: str = "validate",
+                 validate_mode: Union[str, callable] = 'validate',
                  callbacks: Union[List[Callback], Callback, None] = None,
                  metrics: Optional[dict] = None,
                  validate_every: Optional[Union[int, callable]] = -1,
                  input_mapping: Optional[Union[Callable, Dict]] = None,
                  output_mapping: Optional[Union[Callable, Dict]] = None,
+                 model_wo_auto_param_call: bool = False,
                  accumulation_steps: int = 1,
                  fp16: bool = False,
-                 monitor: str = None,
+                 monitor: Union[str, callable] = None,
                  larger_better: bool = True,
                  marker: Optional[str] = None,
                  **kwargs
                  ):
@@ -89,11 +92,8 @@ class Trainer(TrainerEventTrigger):
     :param callbacks: the callbacks triggered during training; this parameter should be a list whose elements all
        inherit from the `Callback` class;
     :param metrics: should be a dict whose keys are monitors, e.g. {"acc1": AccMetric(), "acc2": AccMetric()};
     :param validate_every: may be negative, positive, or a function; negative means validate every that many epochs; positive means validate every that many batches;
-        when a function, it is a user-supplied function controlling how often the Trainer validates; its parameters
-        should be (filter, trainer), where the filter object automatically records two variables: filter.num_called,
-        the number of validation attempts so far (in practice the total number of batches up to now), and
-        filter.num_executed, the number of times validation has actually run; the trainer parameter is the Trainer
-        object. The function should return a bool, True meaning a validation should run. For example:
-        (filter.num_called % trainer.num_batches_per_epoch == 0 and trainer.cur_epoch_idx > 10) means validate at the
-        end of each epoch after the 10th epoch.
+        when a function, it is a user-supplied function controlling how often the Trainer validates; it should accept
+        the current trainer object as its parameter and return a bool, True meaning a validation should run; it is
+        called after every batch to decide whether to validate.
     :param input_mapping: should be a dict or a function describing how to map a batch of training data once it is
        obtained at the current step; if it is a dict and the batch is also a `Dict`, the keys of batch that also
        appear in input_mapping are renamed to the corresponding input_mapping values; if the batch is a `dataclass`,
       it is first converted to a Dict and then transformed as above; if the batch is some other
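Under the new contract, a validate_every function needs nothing beyond the trainer; a minimal sketch (the 100-batch interval is illustrative):

    def validate_every(trainer):
        # called after every batch; returning True triggers a validation run
        return trainer.global_forward_batches % 100 == 0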
@@ -102,12 +102,15 @@ class Trainer(TrainerEventTrigger):
     :param output_mapping: should be a dict or a function. It works like input_mapping, except that it transforms the
        output; if output_mapping is a function, the model's output is passed to it directly; if it is a `Dict`, the
        batch must be of `Dict` or `dataclass` type; if the batch is a `Dict`, the keys of batch that also appear in
        output_mapping are renamed to the corresponding output_mapping values;
-        if the batch is a `dataclass`, it is first converted to a Dict and then transformed as above
+        if the batch is a `dataclass`, it is first converted to a Dict and then transformed as above;
+    :param model_wo_auto_param_call: whether to disable calling our auto_param_call during training to automatically
+        match the batch against the forward function's parameters; if False and the batch is a dict, we extract the
+        objects the forward function needs from the batch and pass them in; if True, we pass the batch straight
+        through to the model. Note that this parameter applies to `train_step`, `validate_step` and `test_step`;
     :param accumulation_steps: the number of gradient-accumulation steps, i.e. the optimizer steps once every that many batches; default 1;
     :param fp16: whether to enable mixed-precision training; default False;
     :param monitor: when validate_dataloaders is present, the default name of the monitor metric. Callbacks that take a
        monitor parameter and did not set it at initialization will use this value. If no exactly matching name is
        found in the evaluation results, the longest-common-substring match is used
-        as the monitor.
+        as the monitor. A function may also be passed; it receives the evaluation results (a dict) and returns a float
+        used as the monitor value.
     :param larger_better: whether a larger monitor value is better.
     :param marker: used to mark a Trainer instance so that when the user calls the `Trainer.on` function, the callback
        is associated with the specific 'trainer' instance it belongs to; default None;
     :param kwargs: other parameters that may be needed;
@@ -126,20 +129,21 @@ class Trainer(TrainerEventTrigger):
            auto means rich is used if the current terminal is detected to be interactive, otherwise raw.
         """

-        # TODO maybe add a parameter so users can turn parameter matching off.
-        self.marker = marker
         self.model = model
-        self.driver_name = driver
+        self.marker = marker
+        if isinstance(driver, str):
+            self.driver_name = driver
+        else:
+            self.driver_name = driver.__class__.__name__
         self.device = device
+        self.optimizers = optimizers
         self.fp16 = fp16
         self.input_mapping = input_mapping
         self.output_mapping = output_mapping

-        assert check_fn_not_empty_params(batch_step_fn, 2), "`batch_step_fn` should be a callable object with " \
-                                                            "two parameters."
         self.batch_step_fn = batch_step_fn
         if batch_step_fn is not None:
+            _check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn')
             self.check_batch_step_fn = partial(self._check_callback_called_legality, check_mode=True)
         else:
             self.check_batch_step_fn = lambda *args, **kwargs: ...
@@ -155,6 +159,8 @@ class Trainer(TrainerEventTrigger):
         elif accumulation_steps < 0:
             raise ValueError("Parameter `accumulation_steps` can only be bigger than 0.")
         self.accumulation_steps = accumulation_steps
+
+        # todo rough idea: each driver declares its own parameters (mapped back to the ones it was initialized with), and the trainer/evaluator checks at initialization whether the parameters it holds agree with the driver's, warning the user about any values that differ. This can probably be left for later.
         self.driver = choose_driver(
             model=model,
             driver=driver,
@@ -171,6 +177,7 @@ class Trainer(TrainerEventTrigger):
             validate_every=validate_every,
             input_mapping=input_mapping,
             output_mapping=output_mapping,
+            model_wo_auto_param_call=model_wo_auto_param_call,
             accumulation_steps=accumulation_steps,
             fp16=fp16,
             marker=marker,
@@ -212,17 +219,11 @@ class Trainer(TrainerEventTrigger):
         if metrics is not None and validate_dataloaders is None:
             raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.")

-        # so that the train loop can check at every step whether validation is needed, we fix ahead of time, at
-        # trainer initialization, the functions to run at the corresponding moments;
-        # _epoch_validate means validate every few epochs; _step_validate means validate every few steps;
         self.evaluator = None
-        self.epoch_validate = lambda *args, **kwargs: ...
-        self.step_validate = lambda *args, **kwargs: ...
         self.monitor = monitor
         self.larger_better = larger_better
         if metrics is not None and validate_dataloaders is not None:
-            if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0):
-                raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.")
+            check_validate_every(validate_every)
             self.evaluator = Evaluator(
                 model=model,
                 dataloaders=validate_dataloaders,
@@ -239,16 +240,6 @@ class Trainer(TrainerEventTrigger):
                 progress_bar=kwargs.get('progress_bar', 'auto')
             )

-            if callable(validate_every):
-                self._step_validate_filter = Filter(filter_fn=validate_every)
-                logger.info("Notice you are using a 'filter function' as the value of parameter `validate_every`, "
-                            "and in this way, the kind of controlling frequency is depending on the 'step'.")
-            elif validate_every < 0:
-                self._epoch_validate_filter = Filter(every=-validate_every)
-            else:
-                # validate_every > 0
-                self._step_validate_filter = Filter(every=validate_every)

         self.metrics = metrics
         self.validate_every = validate_every
@@ -317,6 +308,8 @@ class Trainer(TrainerEventTrigger):
         try:
             while self.cur_epoch_idx < self.n_epochs:
+                # this prevents continuing to save after Trainer.load before the current epoch has finished
+                self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch
                 self.driver.set_model_mode("train")
                 self.on_train_epoch_begin()
                 self.driver.set_sampler_epoch(self.dataloader, self.cur_epoch_idx)
@@ -345,31 +338,37 @@ class Trainer(TrainerEventTrigger): | |||||
raise e | raise e | ||||
def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl): | def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl): | ||||
def _validate_fn(validate_fn: Callable, trainer: Trainer) -> None: | |||||
def _validate_fn(trainer: Trainer, validate_fn: Callable) -> None: | |||||
trainer.on_validate_begin() | trainer.on_validate_begin() | ||||
_validate_res: dict = validate_fn() | _validate_res: dict = validate_fn() | ||||
trainer.on_validate_end(_validate_res) | trainer.on_validate_end(_validate_res) | ||||
self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) | |||||
def step_validate(self): | |||||
""" | |||||
在每个 batch 结束后调用,根据设置执行 evaluate 。 | |||||
:return: | |||||
""" | |||||
if self.evaluator is not None: | if self.evaluator is not None: | ||||
if callable(self.validate_every): | if callable(self.validate_every): | ||||
self.step_validate = self._step_validate_filter(partial( | |||||
_validate_fn, | |||||
partial(self.evaluator.run, num_eval_batch_per_dl), | |||||
self | |||||
)) | |||||
elif self.validate_every < 0: | |||||
self.epoch_validate = self._epoch_validate_filter(partial( | |||||
_validate_fn, | |||||
partial(self.evaluator.run, num_eval_batch_per_dl), | |||||
self | |||||
)) | |||||
else: | |||||
# validate_every > 0 | |||||
self.step_validate = self._step_validate_filter(partial( | |||||
_validate_fn, | |||||
partial(self.evaluator.run, num_eval_batch_per_dl), | |||||
self | |||||
)) | |||||
if self.validate_every(self): | |||||
self.run_evaluate() | |||||
elif self.validate_every > 0 and self.global_forward_batches % self.validate_every == 0: | |||||
self.run_evaluate() | |||||
def epoch_validate(self): | |||||
""" | |||||
在每个 epoch 结束后调用,根据设置执行 evaluate 。 | |||||
:return: | |||||
""" | |||||
if self.evaluator is not None: | |||||
if isinstance(self.validate_every, int) and self.validate_every < 0: | |||||
validate_every = -self.validate_every | |||||
if self.cur_epoch_idx % validate_every == 0: | |||||
self.run_evaluate() | |||||
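A toy re-statement of the step-level gating above may help; this is a sketch, not the Trainer API, and note that the real code passes the trainer instance (not a raw counter) to a callable `validate_every`:

```python
def should_step_validate(validate_every, global_forward_batches):
    # mirrors step_validate's gating (toy version; fastNLP hands the trainer
    # itself to a callable validate_every, not the batch counter)
    if callable(validate_every):
        return bool(validate_every(global_forward_batches))
    return validate_every > 0 and global_forward_batches % validate_every == 0

print([n for n in range(1, 13) if should_step_validate(4, n)])  # [4, 8, 12]
```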
def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): | def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): | ||||
r""" | r""" | ||||
@@ -400,9 +399,8 @@ class Trainer(TrainerEventTrigger): | |||||
def wrapper(fn: Callable) -> Callable: | def wrapper(fn: Callable) -> Callable: | ||||
cls._custom_callbacks[marker].append((event, fn)) | cls._custom_callbacks[marker].append((event, fn)) | ||||
assert check_fn_not_empty_params(fn, len(get_fn_arg_names(getattr(Callback, event.value))) - 1), "Your " \ | |||||
"callback fn's allowed parameters seem not to be equal with the origin callback fn in class " \ | |||||
"`Callback` with the same callback time." | |||||
callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:] | |||||
_check_valid_parameters_number(fn, callback_fn_args) | |||||
return fn | return fn | ||||
return wrapper | return wrapper | ||||
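For illustration, a hypothetical registration that satisfies the new parameter check; the import path and the exact `Events` member naming are assumptions here:

```python
from fastNLP import Trainer
from fastNLP.core.callbacks import Events  # assumed import path

# The registered fn must accept exactly the parameters of the matching
# `Callback` method minus `self`; otherwise `_check_valid_parameters_number`
# raises at registration time rather than mid-training.
@Trainer.on(Events.on_train_epoch_end)     # `Events` member naming assumed
def log_epoch(trainer):                    # matches Callback.on_train_epoch_end(self, trainer)
    print(f"finished epoch {trainer.cur_epoch_idx}")
```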
@@ -431,9 +429,11 @@ class Trainer(TrainerEventTrigger): | |||||
2. 函数作用 | 2. 函数作用 | ||||
这一函数的作用在于检查用户定制的 batch_step_fn / TrainBatchLoop 是否能够正确地调用 callback 函数,更准确地说,当用户实际 | 这一函数的作用在于检查用户定制的 batch_step_fn / TrainBatchLoop 是否能够正确地调用 callback 函数,更准确地说,当用户实际 | ||||
定制了 ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") / | |||||
定制了 ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad", | |||||
"on_after_zero_grad") / | |||||
("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end", | ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end", | ||||
"on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") | |||||
"on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad", | |||||
"on_after_zero_grad") | |||||
这些 callback_fn 后,如果其同样也定制了 batch_step_fn / TrainBatchLoop,那么其有可能忘记了在自己的 batch_step_fn 中调用 | 这些 callback_fn 后,如果其同样也定制了 batch_step_fn / TrainBatchLoop,那么其有可能忘记了在自己的 batch_step_fn 中调用 | ||||
上述的这些 callback 函数,而这个函数的作用就在于检测用户是否产生了这一行为; | 上述的这些 callback 函数,而这个函数的作用就在于检测用户是否产生了这一行为; | ||||
@@ -443,10 +443,12 @@ class Trainer(TrainerEventTrigger): | |||||
'batch_step_fn',为 False 时表示检测 'TrainBatchLoop'; | 'batch_step_fn',为 False 时表示检测 'TrainBatchLoop'; | ||||
""" | """ | ||||
if check_mode: | if check_mode: | ||||
callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") | |||||
callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", | |||||
"on_before_zero_grad", "on_after_zero_grad") | |||||
else: | else: | ||||
callbacks = ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end", | callbacks = ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end", | ||||
"on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") | |||||
"on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", | |||||
"on_before_zero_grad", "on_after_zero_grad") | |||||
_not_called_callback_fns = [] | _not_called_callback_fns = [] | ||||
for each_callback_fn in callbacks: | for each_callback_fn in callbacks: | ||||
if each_callback_fn in self.callback_manager.callback_fns: | if each_callback_fn in self.callback_manager.callback_fns: | ||||
@@ -498,8 +500,6 @@ class Trainer(TrainerEventTrigger): | |||||
@driver.setter | @driver.setter | ||||
def driver(self, driver: Driver): | def driver(self, driver: Driver): | ||||
driver.trainer = self | |||||
driver.model = self.model | |||||
self._driver = driver | self._driver = driver | ||||
@property | @property | ||||
@@ -591,7 +591,9 @@ class Trainer(TrainerEventTrigger): | |||||
# 1. callback states 和 每一个callback的具体 callback 函数的 filter 的状态; | # 1. callback states 和 每一个callback的具体 callback 函数的 filter 的状态; | ||||
# 2. trainer_state; | # 2. trainer_state; | ||||
states = {"callback_states": self.on_save_checkpoint(), | states = {"callback_states": self.on_save_checkpoint(), | ||||
"trainer_state": self.trainer_state.state_dict()} | |||||
"trainer_state": self.trainer_state.state_dict(), | |||||
'num_consumed_batches': self.batch_idx_in_epoch - getattr(self, 'start_batch_idx_in_epoch', 0) | |||||
} | |||||
# 3. validate filter state; | # 3. validate filter state; | ||||
if self.evaluator is not None: | if self.evaluator is not None: | ||||
@@ -668,6 +670,10 @@ class Trainer(TrainerEventTrigger): | |||||
# 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 | # 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 | ||||
# '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 | # '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 | ||||
self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') | self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') | ||||
self.trainer_state.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + \ | |||||
self.batch_idx_in_epoch | |||||
# 这个是防止用户在 Trainer.load 之后还没结束当前 epoch 又继续 save | |||||
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | |||||
# 5. 恢复所有 callback 的状态; | # 5. 恢复所有 callback 的状态; | ||||
self.on_load_checkpoint(states["callback_states"]) | self.on_load_checkpoint(states["callback_states"]) | ||||
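The interplay between `start_batch_idx_in_epoch` and the saved `num_consumed_batches` is plain subtraction; a sketch with made-up numbers:

```python
# After Trainer.load restores batch_idx_in_epoch = 5, training continues and
# Trainer.save is called at batch_idx_in_epoch = 12 of the same epoch:
start_batch_idx_in_epoch = 5            # remembered at load time
batch_idx_in_epoch = 12                 # current position when saving
num_consumed_batches = batch_idx_in_epoch - start_batch_idx_in_epoch
print(num_consumed_batches)             # 7 batches actually consumed since the load
```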
@@ -692,13 +698,15 @@ class Trainer(TrainerEventTrigger): | |||||
def zero_grad(self): | def zero_grad(self): | ||||
if (self.global_forward_batches + 1) % self.accumulation_steps == 0: | if (self.global_forward_batches + 1) % self.accumulation_steps == 0: | ||||
self.on_before_zero_grad(self.driver.optimizers) | |||||
self.on_before_zero_grad(self.optimizers) | |||||
self.driver.zero_grad(self.set_grad_to_none) | self.driver.zero_grad(self.set_grad_to_none) | ||||
self.on_after_zero_grad(self.optimizers) | |||||
def step(self): | def step(self): | ||||
if (self.global_forward_batches + 1) % self.accumulation_steps == 0: | if (self.global_forward_batches + 1) % self.accumulation_steps == 0: | ||||
self.on_before_optimizer_step(self.driver.optimizers) | |||||
self.on_before_optimizers_step(self.optimizers) | |||||
self.driver.step() | self.driver.step() | ||||
self.on_after_optimizers_step(self.optimizers) | |||||
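Both hooks share the same `(global_forward_batches + 1) % accumulation_steps == 0` gate; a self-contained trace of when the optimizers actually fire:

```python
accumulation_steps = 4
for global_forward_batches in range(8):
    will_update = (global_forward_batches + 1) % accumulation_steps == 0
    print(global_forward_batches, will_update)
# optimizers step / grads are zeroed only on batches 3 and 7,
# i.e. once per accumulation_steps forward passes
```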
def move_data_to_device(self, batch): | def move_data_to_device(self, batch): | ||||
return self.driver.move_data_to_device(batch) | return self.driver.move_data_to_device(batch) | ||||
@@ -796,4 +804,19 @@ class Trainer(TrainerEventTrigger): | |||||
def total_batches(self, total_batches: int): | def total_batches(self, total_batches: int): | ||||
self.trainer_state.total_batches = total_batches | self.trainer_state.total_batches = total_batches | ||||
""" driver property """ | |||||
@property | |||||
def model_device(self): | |||||
return self.driver.model_device | |||||
@property | |||||
def data_device(self): | |||||
return self.driver.data_device | |||||
@@ -60,7 +60,7 @@ class TrainerState: | |||||
cur_epoch_idx: 当前正在运行第几个 epoch; | cur_epoch_idx: 当前正在运行第几个 epoch; | ||||
global_forward_batches: 当前模型总共 forward 了多少个 step; | global_forward_batches: 当前模型总共 forward 了多少个 step; | ||||
batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step; | batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step; | ||||
total_batches: 每一个 epoch 会 forward 多少个 step; | |||||
num_batches_per_epoch: 每一个 epoch 会 forward 多少个 step; | |||||
total_batches: 完整训练过程会 forward 的 step 数量,注意 total_batches = num_batches_per_epoch * n_epochs; | total_batches: 完整训练过程会 forward 的 step 数量,注意 total_batches = num_batches_per_epoch * n_epochs; | ||||
""" | """ | ||||
n_epochs: Optional[int] = None # 无论如何重新算 | n_epochs: Optional[int] = None # 无论如何重新算 | ||||
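With the corrected definition, the two count fields relate as follows (made-up numbers):

```python
n_epochs = 10
num_batches_per_epoch = 250                       # e.g. len(train_dataloader)
total_batches = num_batches_per_epoch * n_epochs  # 2500 forward steps overall
```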
@@ -1,8 +1,9 @@ | |||||
from collections.abc import Iterator | |||||
import inspect | |||||
from typing import Dict | from typing import Dict | ||||
from fastNLP.core.callbacks import CallbackManager | from fastNLP.core.callbacks import CallbackManager | ||||
from .state import TrainerState | from .state import TrainerState | ||||
from fastNLP.core.utils.utils import _check_valid_parameters_number | |||||
class TrainerEventTrigger: | class TrainerEventTrigger: | ||||
@@ -68,12 +69,18 @@ class TrainerEventTrigger: | |||||
def on_after_backward(self): | def on_after_backward(self): | ||||
self.callback_manager.on_after_backward(self) | self.callback_manager.on_after_backward(self) | ||||
def on_before_optimizer_step(self, optimizers): | |||||
self.callback_manager.on_before_optimizer_step(self, optimizers) | |||||
def on_before_optimizers_step(self, optimizers): | |||||
self.callback_manager.on_before_optimizers_step(self, optimizers) | |||||
def on_after_optimizers_step(self, optimizers): | |||||
self.callback_manager.on_after_optimizers_step(self, optimizers) | |||||
def on_before_zero_grad(self, optimizers): | def on_before_zero_grad(self, optimizers): | ||||
self.callback_manager.on_before_zero_grad(self, optimizers) | self.callback_manager.on_before_zero_grad(self, optimizers) | ||||
def on_after_zero_grad(self, optimizers): | |||||
self.callback_manager.on_after_zero_grad(self, optimizers) | |||||
def on_validate_begin(self): | def on_validate_begin(self): | ||||
self.callback_manager.on_validate_begin(self) | self.callback_manager.on_validate_begin(self) | ||||
@@ -119,5 +126,8 @@ class _TruncatedDataLoader: | |||||
return getattr(self.dataloader, item) | return getattr(self.dataloader, item) | ||||
def check_validate_every(validate_every): | |||||
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | |||||
raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") | |||||
if callable(validate_every): | |||||
_check_valid_parameters_number(validate_every, expected_params=['trainer']) |
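Examples of the values this checker accepts and rejects, assuming the function as defined above:

```python
check_validate_every(100)   # ok: positive int -> validate every 100 steps
check_validate_every(-1)    # ok: negative int -> validate every epoch
check_validate_every(lambda trainer: trainer.cur_epoch_idx >= 2)  # ok: one 'trainer' param

try:
    check_validate_every(0)  # neither callable nor a non-zero int
except ValueError as e:
    print(e)
```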
@@ -54,7 +54,7 @@ class TorchDataLoader(DataLoader): | |||||
pin_memory: bool = False, drop_last: bool = False, | pin_memory: bool = False, drop_last: bool = False, | ||||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | timeout: float = 0, worker_init_fn: Optional[Callable] = None, | ||||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | ||||
persistent_workers: bool = False, as_numpy: bool = False) -> None: | |||||
persistent_workers: bool = False, as_numpy: bool = False, **kwargs) -> None: | |||||
""" | """ | ||||
:param dataset: 实现了__getitem__和__len__的数据容器 | :param dataset: 实现了__getitem__和__len__的数据容器 | ||||
@@ -178,10 +178,11 @@ class DataSet: | |||||
elif isinstance(idx, slice): | elif isinstance(idx, slice): | ||||
if idx.start is not None and (idx.start >= len(self) or idx.start <= -len(self)): | if idx.start is not None and (idx.start >= len(self) or idx.start <= -len(self)): | ||||
raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self) - 1}") | raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self) - 1}") | ||||
data_set = DataSet() | |||||
dataset = DataSet() | |||||
for field_name, field in self.field_arrays.items(): | for field_name, field in self.field_arrays.items(): | ||||
data_set.add_field(field_name=field_name, fields=field.content[idx]) | |||||
return data_set | |||||
dataset.add_field(field_name=field_name, fields=field.content[idx]) | |||||
dataset.collate_fns = deepcopy(self.collate_fns) | |||||
return dataset | |||||
elif isinstance(idx, str): | elif isinstance(idx, str): | ||||
if idx not in self: | if idx not in self: | ||||
raise KeyError("No such field called {} in DataSet.".format(idx)) | raise KeyError("No such field called {} in DataSet.".format(idx)) | ||||
@@ -192,6 +193,7 @@ class DataSet: | |||||
assert isinstance(i, int), "Only int index allowed." | assert isinstance(i, int), "Only int index allowed." | ||||
instance = self[i] | instance = self[i] | ||||
dataset.append(instance) | dataset.append(instance) | ||||
dataset.collate_fns = deepcopy(self.collate_fns) | |||||
return dataset | return dataset | ||||
else: | else: | ||||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | ||||
@@ -674,6 +676,8 @@ class DataSet: | |||||
dev_set.append(self[idx]) | dev_set.append(self[idx]) | ||||
for idx in train_indices: | for idx in train_indices: | ||||
train_set.append(self[idx]) | train_set.append(self[idx]) | ||||
dev_set.collate_fns = deepcopy(self.collate_fns) | |||||
train_set.collate_fns = deepcopy(self.collate_fns) | |||||
return dev_set, train_set | return dev_set, train_set | ||||
@@ -788,13 +792,14 @@ class DataSet: | |||||
def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: | def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: | ||||
""" | """ | ||||
设置每个field_name的padding值,默认为0,只有当Auto_collate存在时该方法有效 | |||||
设置每个field_name的padding值,默认为0,只有当AutoCollator存在时该方法有效 | |||||
当val=None时,意味着给定的field_names都不需要尝试padding | 当val=None时,意味着给定的field_names都不需要尝试padding | ||||
:param field_names: dataset存在的field_name | :param field_names: dataset存在的field_name | ||||
:param val: 默认为0 | |||||
:param val: 默认为0。如果为 None ,则为不对 field 进行 padding 。 | |||||
:return: | :return: | ||||
""" | """ | ||||
# TODO 不能为空 | |||||
for field_name in field_names: | for field_name in field_names: | ||||
self.collate_fns.set_pad_val(field_name, val=val) | self.collate_fns.set_pad_val(field_name, val=val) | ||||
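Usage sketch of the clarified `val=None` semantics (same assumed constructor as above):

```python
from fastNLP.core.dataset import DataSet  # assumed import path

ds = DataSet({"words": [[1, 2], [3]], "raw_words": [["a b"], ["c d e"]]})
ds.set_pad_val("words", val=0)         # pad `words` with 0 (the default)
ds.set_pad_val("raw_words", val=None)  # None: do not attempt to pad this field
```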
@@ -805,6 +810,7 @@ class DataSet: | |||||
:param field_names: | :param field_names: | ||||
:return: | :return: | ||||
""" | """ | ||||
self.collate_fns.set_input(*field_names) | self.collate_fns.set_input(*field_names) | ||||
def get_collator(self) -> _MultiCollator: | def get_collator(self) -> _MultiCollator: | ||||
@@ -66,7 +66,7 @@ class JittorDriver(Driver): | |||||
if mode == "validate": | if mode == "validate": | ||||
if not hasattr(model, "validate_step"): | if not hasattr(model, "validate_step"): | ||||
if hasattr(model, "test_step"): | if hasattr(model, "test_step"): | ||||
logger.warning( | |||||
logger.warning_once( | |||||
"Your model does not have 'validate_step' method but has 'test_step' method, but you" | "Your model does not have 'validate_step' method but has 'test_step' method, but you" | ||||
"are using 'mode=validate', we are going to use 'test_step' to substitute for" | "are using 'mode=validate', we are going to use 'test_step' to substitute for" | ||||
"'validate_step'.") | "'validate_step'.") | ||||
@@ -74,7 +74,7 @@ class JittorDriver(Driver): | |||||
else: | else: | ||||
if not hasattr(model, "test_step"): | if not hasattr(model, "test_step"): | ||||
if hasattr(model, "validate_step"): | if hasattr(model, "validate_step"): | ||||
logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you" | |||||
logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" | |||||
"are using 'mode=test', we are going to use 'validate_step' to substitute for" | "are using 'mode=test', we are going to use 'validate_step' to substitute for" | ||||
"'test_step'.") | "'test_step'.") | ||||
@@ -10,6 +10,8 @@ from .utils import ( | |||||
_MODE_PARAMETER, | _MODE_PARAMETER, | ||||
get_device_from_visible, | get_device_from_visible, | ||||
reset_seed, | reset_seed, | ||||
replace_sampler, | |||||
replace_batch_sampler, | |||||
) | ) | ||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
@@ -19,8 +21,17 @@ from fastNLP.core.utils import ( | |||||
paddle_move_data_to_device, | paddle_move_data_to_device, | ||||
is_in_paddle_dist, | is_in_paddle_dist, | ||||
) | ) | ||||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedRandomSampler | |||||
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES | |||||
from fastNLP.core.samplers import ( | |||||
RandomBatchSampler, | |||||
ReproducibleSampler, | |||||
ReproducibleBatchSampler, | |||||
RandomSampler, | |||||
UnrepeatedSampler, | |||||
UnrepeatedSequentialSampler, | |||||
re_instantiate_sampler, | |||||
conversion_between_reproducible_and_unrepeated_sampler, | |||||
) | |||||
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
@@ -93,8 +104,8 @@ class PaddleFleetDriver(PaddleDriver): | |||||
# 我们就直接将 model_device 置为 None; | # 我们就直接将 model_device 置为 None; | ||||
self._model_device = None | self._model_device = None | ||||
def _running_fn_(batch, step_fn, signature_fn): | |||||
if isinstance(batch, Dict): | |||||
def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call): | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(step_fn, batch, signature_fn=signature_fn) | return auto_param_call(step_fn, batch, signature_fn=signature_fn) | ||||
else: | else: | ||||
return self._validate_step(batch) | return self._validate_step(batch) | ||||
@@ -105,23 +116,21 @@ class PaddleFleetDriver(PaddleDriver): | |||||
"Notice your model is a `paddle.DataParallel` model. And your " | "Notice your model is a `paddle.DataParallel` model. And your " | ||||
"model also implements the `train_step` method, which we can not call actually, we will" | "model also implements the `train_step` method, which we can not call actually, we will" | ||||
" call `forward` function instead of `train_step` and you should note that.") | " call `forward` function instead of `train_step` and you should note that.") | ||||
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||||
# self._train_signature_fn = model.forward | |||||
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
if hasattr(model, "validate_step"): | if hasattr(model, "validate_step"): | ||||
logger.warning( | logger.warning( | ||||
"Notice your model is a `paddle.DataParallel` model. And your " | "Notice your model is a `paddle.DataParallel` model. And your " | ||||
"model also implements the `validate_step` method, which we can not call actually, " | "model also implements the `validate_step` method, which we can not call actually, " | ||||
"we will call `forward` function instead of `validate_step` and you should note that.") | "we will call `forward` function instead of `validate_step` and you should note that.") | ||||
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||||
# self._validate_signature_fn = model.forward | |||||
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
if hasattr(model, "test_step"): | if hasattr(model, "test_step"): | ||||
logger.warning( | logger.warning( | ||||
"Notice your model is a `paddle.DataParallel` model. And your " | "Notice your model is a `paddle.DataParallel` model. And your " | ||||
"model also implements the `test_step` method, which we can not call actually, we will" | "model also implements the `test_step` method, which we can not call actually, we will" | ||||
" call `forward` function instead of `test_step` and you should note that.") | " call `forward` function instead of `test_step` and you should note that.") | ||||
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||||
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
# 当参数 `device` 为 None 时并且该参数不为 None,表示将对应的数据移到指定的机器上; | # 当参数 `device` 为 None 时并且该参数不为 None,表示将对应的数据移到指定的机器上; | ||||
self._data_device = kwargs.get("data_device", None) | self._data_device = kwargs.get("data_device", None) | ||||
@@ -235,7 +244,6 @@ class PaddleFleetDriver(PaddleDriver): | |||||
""" | """ | ||||
if self.local_rank == 0: | if self.local_rank == 0: | ||||
# 是 rank0 的话,则拉起其它子进程 | # 是 rank0 的话,则拉起其它子进程 | ||||
print("in launcher") | |||||
launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc) | launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc) | ||||
launcher.launch() | launcher.launch() | ||||
# 设置参数和初始化分布式环境 | # 设置参数和初始化分布式环境 | ||||
@@ -253,7 +261,6 @@ class PaddleFleetDriver(PaddleDriver): | |||||
当用户使用了 `python -m paddle.distributed.launch xxx.py` 启动时,我们需要 | 当用户使用了 `python -m paddle.distributed.launch xxx.py` 启动时,我们需要 | ||||
根据 paddle 设置的环境变量来获得各种属性 | 根据 paddle 设置的环境变量来获得各种属性 | ||||
""" | """ | ||||
print("set_from_env") | |||||
self.world_size = dist.get_world_size() | self.world_size = dist.get_world_size() | ||||
self.global_rank = dist.get_rank() | self.global_rank = dist.get_rank() | ||||
@@ -267,9 +274,9 @@ class PaddleFleetDriver(PaddleDriver): | |||||
**self._fleet_kwargs | **self._fleet_kwargs | ||||
) | ) | ||||
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}) | |||||
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}) | |||||
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}) | |||||
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call) | |||||
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call) | |||||
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call) | |||||
self._configured = True | self._configured = True | ||||
@@ -312,67 +319,90 @@ class PaddleFleetDriver(PaddleDriver): | |||||
def test_step(self, batch): | def test_step(self, batch): | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]], | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]], | |||||
reproducible: bool = False, sampler_or_batch_sampler=None): | reproducible: bool = False, sampler_or_batch_sampler=None): | ||||
# 暂时不支持iterableDataset | # 暂时不支持iterableDataset | ||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | assert dataloader.dataset_kind != _DatasetKind.ITER, \ | ||||
"FastNLP does not support `IteratorDataset` now." | "FastNLP does not support `IteratorDataset` now." | ||||
# 如果 dist 为 ReproducibleBatchSampler 或 ReproducibleSampler,说明是在断点重训时 driver.load 函数调用; | |||||
if isinstance(dist, ReproducibleBatchSampler): | |||||
dist.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_batch_sampler(dataloader, dist) | |||||
if isinstance(dist, ReproducibleSampler): | if isinstance(dist, ReproducibleSampler): | ||||
dataloader.batch_sampler.sampler = dist | |||||
return dataloader | |||||
# paddle 的 BatchSampler 和 DataLoader 没有 shuffle 成员,只能根据 sampler 判断 | |||||
# 但是其子类 DistributedBatchSampler 却有 shuffle 成员 | |||||
# 因此用 type() 进行严格的判断 | |||||
if type(dataloader.batch_sampler) == BatchSampler: | |||||
shuffle = isinstance(dataloader.batch_sampler.sampler, RandomSampler) | |||||
else: | |||||
shuffle = dataloader.batch_sampler.shuffle | |||||
dist.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_sampler(dataloader, dist) | |||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初始化时调用; | |||||
# trainer, evaluator | # trainer, evaluator | ||||
if dist is None: | if dist is None: | ||||
if reproducible: | if reproducible: | ||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our " | raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our " | ||||
"control.") | "control.") | ||||
else: | else: | ||||
args = self.get_dataloader_args(dataloader) | |||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
if isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
return replace_sampler(dataloader, sampler) | |||||
return dataloader | return dataloader | ||||
# trainer | # trainer | ||||
elif dist == "dist": | elif dist == "dist": | ||||
args = self.get_dataloader_args(dataloader) | |||||
# 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | ||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||||
dataloader.batch_sampler.sampler.set_distributed( | |||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||||
batch_sampler.set_distributed( | |||||
num_replicas=self.world_size, | num_replicas=self.world_size, | ||||
rank=self.global_rank, | rank=self.global_rank, | ||||
pad=True | pad=True | ||||
) | ) | ||||
return dataloader | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
elif isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
sampler.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_sampler(dataloader, sampler) | |||||
else: | else: | ||||
sampler = RandomSampler( | sampler = RandomSampler( | ||||
dataset=dataloader.dataset, | |||||
shuffle=shuffle, | |||||
seed=int(os.environ.get("FASTNLP_SEED", 0)) | |||||
dataset=args.dataset, | |||||
shuffle=args.shuffle, | |||||
seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0)) | |||||
) | ) | ||||
sampler.set_distributed( | sampler.set_distributed( | ||||
num_replicas=self.world_size, | num_replicas=self.world_size, | ||||
rank=self.global_rank, | rank=self.global_rank, | ||||
pad=True | pad=True | ||||
) | ) | ||||
dataloader.batch_sampler.sampler = sampler | |||||
return dataloader | |||||
return replace_sampler(dataloader, sampler) | |||||
# evaluator | # evaluator | ||||
elif dist == "unrepeatdist": | elif dist == "unrepeatdist": | ||||
sampler = UnrepeatedRandomSampler( | |||||
dataset=dataloader.dataset, | |||||
shuffle=shuffle, | |||||
seed=int(os.environ.get("FASTNLP_SEED", 0)) | |||||
) | |||||
args = self.get_dataloader_args(dataloader) | |||||
if isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) | |||||
elif not isinstance(args.sampler, UnrepeatedSampler): | |||||
sampler = UnrepeatedSequentialSampler( | |||||
dataset=args.dataset | |||||
) | |||||
else: | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
sampler.set_distributed( | sampler.set_distributed( | ||||
num_replicas=self.world_size, | num_replicas=self.world_size, | ||||
rank=self.global_rank | rank=self.global_rank | ||||
) | ) | ||||
dataloader.batch_sampler.sampler = sampler | |||||
return dataloader | |||||
return replace_sampler(dataloader, sampler) | |||||
else: | else: | ||||
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | ||||
@@ -38,23 +38,19 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
if driver not in {"paddle", "fleet"}: | if driver not in {"paddle", "fleet"}: | ||||
raise ValueError("Parameter `driver` can only be one of these values: ['paddle', 'fleet'].") | raise ValueError("Parameter `driver` can only be one of these values: ['paddle', 'fleet'].") | ||||
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") | |||||
user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES") | user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES") | ||||
# 优先级 user > cuda | |||||
# 判断单机情况 device 的合法性 | |||||
# 分布式情况下通过 world_device 判断 | |||||
if user_visible_devices != "": | |||||
_could_use_device_num = len(user_visible_devices.split(",")) | |||||
elif cuda_visible_devices is not None: | |||||
_could_use_device_num = len(cuda_visible_devices.split(",")) | |||||
else: | |||||
_could_use_device_num = paddle.device.cuda.device_count() | |||||
if user_visible_devices is None: | |||||
raise RuntimeError("This situation cannot happen, please report a bug to us.") | |||||
_could_use_device_num = len(user_visible_devices.split(",")) | |||||
if isinstance(device, int): | if isinstance(device, int): | ||||
if device < 0 and device != -1: | if device < 0 and device != -1: | ||||
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | ||||
# if device >= _could_use_device_num: | |||||
# raise ValueError("The gpu device that parameter `device` specifies is not existed.") | |||||
device = f"gpu:{device}" | |||||
if device >= _could_use_device_num: | |||||
raise ValueError("The gpu device that parameter `device` specifies is not existed.") | |||||
if device != -1: | |||||
device = f"gpu:{device}" | |||||
else: | |||||
device = list(range(_could_use_device_num)) | |||||
elif isinstance(device, Sequence) and not isinstance(device, str): | elif isinstance(device, Sequence) and not isinstance(device, str): | ||||
device = list(set(device)) | device = list(set(device)) | ||||
for each in device: | for each in device: | ||||
@@ -62,6 +58,9 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be 'int' type.") | raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be 'int' type.") | ||||
elif each < 0: | elif each < 0: | ||||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be bigger than 0.") | raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be bigger than 0.") | ||||
elif each >= _could_use_device_num: | |||||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should not be bigger than" | |||||
" the available gpu number.") | |||||
if len(device) == 1: | if len(device) == 1: | ||||
# 传入了 [1] 这样的,视为单卡。 | # 传入了 [1] 这样的,视为单卡。 | ||||
device = device[0] | device = device[0] | ||||
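A toy re-statement of the validation rules above (hypothetical helper, not the fastNLP function):

```python
def normalize_device(device, could_use_device_num):
    # could_use_device_num mirrors len(USER_CUDA_VISIBLE_DEVICES.split(","))
    if isinstance(device, int):
        if device < -1:
            raise ValueError("only -1 is allowed among negative ints")
        if device >= could_use_device_num:
            raise ValueError("device id exceeds the visible gpu count")
        return list(range(could_use_device_num)) if device == -1 else f"gpu:{device}"
    device = list(set(device))
    if any(not isinstance(each, int) or each < 0 or each >= could_use_device_num
           for each in device):
        raise ValueError("every id must be an int in [0, visible gpu count)")
    return device[0] if len(device) == 1 else device

print(normalize_device(-1, 4))   # [0, 1, 2, 3]
print(normalize_device([1], 4))  # 1 -- a single-element list means one card
```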
@@ -1,21 +1,36 @@ | |||||
import os | import os | ||||
import random | import random | ||||
from typing import Union, Optional, Callable, Dict | |||||
from typing import Union, Optional, Dict | |||||
from pathlib import Path | |||||
from functools import partial | from functools import partial | ||||
from dataclasses import dataclass | |||||
import numpy as np | import numpy as np | ||||
from .utils import _build_fp16_env | |||||
from .utils import _build_fp16_env, optimizer_state_to_device | |||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
from fastNLP.core.drivers.driver import Driver | from fastNLP.core.drivers.driver import Driver | ||||
from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device | from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device | ||||
from fastNLP.envs import rank_zero_call | |||||
from fastNLP.envs import FASTNLP_SEED_WORKERS | |||||
from fastNLP.envs import ( | |||||
FASTNLP_SEED_WORKERS, | |||||
FASTNLP_MODEL_FILENAME, | |||||
FASTNLP_CHECKPOINT_FILENAME, | |||||
FASTNLP_GLOBAL_RANK, | |||||
rank_zero_call, | |||||
) | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler | |||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
import paddle | import paddle | ||||
from paddle.io import DataLoader, IterableDataset | |||||
from paddle.io import ( | |||||
DataLoader, | |||||
IterableDataset, | |||||
Dataset, | |||||
Sampler, | |||||
BatchSampler, | |||||
RandomSampler, | |||||
) | |||||
from paddle.optimizer import Optimizer | from paddle.optimizer import Optimizer | ||||
_reduces = { | _reduces = { | ||||
@@ -41,6 +56,9 @@ class PaddleDriver(Driver): | |||||
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | ||||
self.grad_scaler = _grad_scaler() | self.grad_scaler = _grad_scaler() | ||||
# 用来设置是否关闭 auto_param_call 中的参数匹配问题; | |||||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | |||||
def zero_grad(self, set_to_none: bool = False): | def zero_grad(self, set_to_none: bool = False): | ||||
r""" | r""" | ||||
实现深度学习中的梯度的置零操作,应当直接通过优化器 optimizers 来将梯度置零; | 实现深度学习中的梯度的置零操作,应当直接通过优化器 optimizers 来将梯度置零; | ||||
@@ -48,8 +66,8 @@ class PaddleDriver(Driver): | |||||
:param set_to_none: 用来判断是否需要将梯度直接置为 None;Paddle中这个参数无效。 | :param set_to_none: 用来判断是否需要将梯度直接置为 None;Paddle中这个参数无效。 | ||||
""" | """ | ||||
# if set_to_none: | |||||
# log.warning("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.") | |||||
if set_to_none: | |||||
logger.warning_once("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.") | |||||
for optimizer in self.optimizers: | for optimizer in self.optimizers: | ||||
optimizer.clear_grad() | optimizer.clear_grad() | ||||
@@ -69,6 +87,8 @@ class PaddleDriver(Driver): | |||||
# TODO 我们先禁止 dataloader 的 dataset 是 IterableDataset 种类; | # TODO 我们先禁止 dataloader 的 dataset 是 IterableDataset 种类; | ||||
if isinstance(dataloader.dataset, IterableDataset): | if isinstance(dataloader.dataset, IterableDataset): | ||||
raise TypeError("`IterableDataset` is not allowed.") | raise TypeError("`IterableDataset` is not allowed.") | ||||
if dataloader.batch_sampler is None and dataloader.batch_size is None: | |||||
raise ValueError(f"At least one of `{dataloader_name}`'s `batch_sampler` and `batch_size` should be set.") | |||||
else: | else: | ||||
if not isinstance(dataloader, Dict): | if not isinstance(dataloader, Dict): | ||||
raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.") | raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.") | ||||
@@ -79,6 +99,9 @@ class PaddleDriver(Driver): | |||||
f"type, not {type(each_dataloader)}.") | f"type, not {type(each_dataloader)}.") | ||||
if isinstance(each_dataloader.dataset, IterableDataset): | if isinstance(each_dataloader.dataset, IterableDataset): | ||||
raise TypeError("`IterableDataset` is not allowed.") | raise TypeError("`IterableDataset` is not allowed.") | ||||
if each_dataloader.batch_sampler is None and each_dataloader.batch_size is None: | |||||
raise ValueError(f"For each dataloader of parameter `{dataloader_name}`, at least one of " | |||||
f"`batch_sampler` and `batch_size` should be set.") | |||||
@staticmethod | @staticmethod | ||||
def _check_optimizer_legality(optimizers): | def _check_optimizer_legality(optimizers): | ||||
@@ -110,7 +133,7 @@ class PaddleDriver(Driver): | |||||
else: | else: | ||||
if not hasattr(model, "test_step"): | if not hasattr(model, "test_step"): | ||||
if hasattr(model, "validate_step"): | if hasattr(model, "validate_step"): | ||||
logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you" | |||||
logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" | |||||
"are using 'Evaluator.test', we are going to use 'validate_step' to substitute for" | "are using 'Evaluator.test', we are going to use 'validate_step' to substitute for" | ||||
"'test_step'.") | "'test_step'.") | ||||
@@ -153,45 +176,55 @@ class PaddleDriver(Driver): | |||||
getattr(self.model, mode)() | getattr(self.model, mode)() | ||||
@rank_zero_call | @rank_zero_call | ||||
def save_model(self, filepath: str, only_state_dict: bool = True, model_save_fn: Optional[Callable]=None, **kwargs): | |||||
def save_model(self, filepath: str, only_state_dict: bool = True, **kwargs): | |||||
r""" | r""" | ||||
保存模型的函数;注意函数 `save` 是用来进行断点重训的函数; | 保存模型的函数;注意函数 `save` 是用来进行断点重训的函数; | ||||
如果 `model_save_fn` 是一个可调用的函数,那么我们会直接运行该函数; | |||||
:param filepath: 保存文件的文件位置(需要包括文件名); | :param filepath: 保存文件的文件位置(需要包括文件名); | ||||
:param only_state_dict: 是否只保存模型的 `state_dict`;注意该参数仅当 `model_save_fn` 为 None 时有效; | |||||
:param model_save_fn: 用户传入的用来代替该函数本身保存逻辑的函数;如果该参数不为 None,那么我们会调用 model_save_fn(path); | |||||
:param only_state_dict: 是否只保存模型的 `state_dict`;如果为 False,则会调用 `paddle.jit.save` 函数 | |||||
保存整个模型的参数,此时需要传入 `input_spec` 参数,否则在 load 时会报错。 | |||||
:param kwargs: | |||||
input_spec: 描述存储模型 forward 方法的输入,当 `only_state_dict` 为 False时必须传入,否则加载时会报错。 | |||||
可以通过 InputSpec 或者示例 Tensor 进行描述。详细的可以参考 paddle 关于`paddle.jit.save` | |||||
的文档: | |||||
https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/jit/save_cn.html#save | |||||
:return: | |||||
""" | """ | ||||
if model_save_fn is not None: | |||||
model_save_fn(filepath) | |||||
model = self.unwrap_model() | |||||
if isinstance(filepath, Path): | |||||
filepath = str(filepath) | |||||
if only_state_dict: | |||||
states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} | |||||
paddle.save(states, filepath) | |||||
else: | else: | ||||
model = self.unwrap_model() | |||||
if only_state_dict: | |||||
paddle.save(model.state_dict(), filepath) | |||||
else: | |||||
input_spec = kwargs.get("input_spec", None) | |||||
if input_spec is None: | |||||
raise Exception("To save the whole Paddle Layer, parameter 'input_spec' is needed.") | |||||
paddle.jit.save(model, filepath, input_spec) | |||||
# paddle 在保存整个模型时需要传入额外参数 | |||||
input_spec = kwargs.get("input_spec", None) | |||||
if input_spec is None: | |||||
raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.") | |||||
paddle.jit.save(model, filepath, input_spec) | |||||
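What the two branches boil down to in raw paddle calls; a sketch with a toy layer, assuming paddle is installed:

```python
import paddle
from paddle import nn
from paddle.static import InputSpec

net = nn.Linear(32, 2)

# only_state_dict=True: plain state dict (the driver moves params to cpu first)
paddle.save(net.state_dict(), "model.pdparams")

# only_state_dict=False: whole model via paddle.jit.save; input_spec is mandatory
paddle.jit.save(net, "model_jit",
                input_spec=[InputSpec(shape=[None, 32], dtype="float32")])
```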
@staticmethod | |||||
@rank_zero_call | |||||
def load_model(filepath: str, load_dict: bool = True): | |||||
def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs): | |||||
r""" | r""" | ||||
加载模型的函数;注意函数 `load` 是用来进行断点重训的函数; | 加载模型的函数;注意函数 `load` 是用来进行断点重训的函数; | ||||
:param filepath: 需要被加载的对象的文件位置(需要包括文件名); | :param filepath: 需要被加载的对象的文件位置(需要包括文件名); | ||||
:param load_dict: 是否加载state_dict,默认为True。当用户在save_model时将only_state_dict设置为False时, | |||||
即保存了整个模型时,这个参数必须也为False | |||||
:return: 返回加载指定文件后的结果; | |||||
:param only_state_dict: 是否加载state_dict,默认为True。 | |||||
:param kwargs: | |||||
:return: | |||||
""" | """ | ||||
if load_dict: | |||||
return paddle.load(filepath) | |||||
else: | |||||
return paddle.jit.load(filepath) | |||||
model = self.unwrap_model() | |||||
if isinstance(filepath, Path): | |||||
filepath = str(filepath) | |||||
# paddle 中,通过 paddle.jit.save 函数保存的模型也可以通过 paddle.load 加载为相应的 state dict | |||||
# 但是此时对输入的 path 有要求,必须是 dir/filename 的形式,否则会报错。 | |||||
dirname, filename = os.path.split(filepath) | |||||
if not only_state_dict and dirname == "": | |||||
# 如果传入的是单个文件,则加上相对路径 | |||||
filepath = os.path.join(".", filepath) | |||||
model.load_dict(paddle.load(filepath)) | |||||
@rank_zero_call | @rank_zero_call | ||||
def save(self, folder, states: Dict): | |||||
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||||
r""" | r""" | ||||
断点重训的保存函数,该函数会负责保存模型和 optimizers 的 state_dict; | 断点重训的保存函数,该函数会负责保存模型和 optimizers 的 state_dict; | ||||
需要注意 driver 应当是无状态的,即不管什么时候调用 driver 的接口函数,其返回的结果应该都是一样的;因此,断点重训不需要保存 driver | 需要注意 driver 应当是无状态的,即不管什么时候调用 driver 的接口函数,其返回的结果应该都是一样的;因此,断点重训不需要保存 driver | ||||
@@ -203,48 +236,114 @@ class PaddleDriver(Driver): | |||||
:param states: 由 trainer 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态,Driver 应该只需要保存 | :param states: 由 trainer 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态,Driver 应该只需要保存 | ||||
该对象即可, Driver 应该不需要理解该对象,同时在 driver.load() 的时候,需要将 states 返回回去,load()返回的值与这里的 | 该对象即可, Driver 应该不需要理解该对象,同时在 driver.load() 的时候,需要将 states 返回回去,load()返回的值与这里的 | ||||
传入的值保持一致。 | 传入的值保持一致。 | ||||
:param dataloader: 正在使用的 dataloader,需要保存里面的状态使得之后可以从当前迭代的位置恢复。 | |||||
:param only_state_dict: 是否只保存模型的参数,当 should_save_model 为 False ,该参数无效。 | |||||
:param should_save_model: 是否应该保存模型,如果为False,Driver 将不负责 model 的保存。 | |||||
:return: | |||||
""" | """ | ||||
# 1. 保存模型的状态; | |||||
model = self.unwrap_model() | |||||
model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} | |||||
# 对于单卡的 driver 来讲,我们实际上(现在)不应该考虑用户在DDP环境下使用单卡模式,从而造成效率损失; | |||||
states["model_state_dict"] = model_state_dict | |||||
# 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变 | |||||
# trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; | |||||
# 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; | |||||
# paddle 的 DataLoader 在初始化之后 batch_sampler 可能为 None,也可能为用户设置的 batch_sampler | |||||
dataloader_args = self.get_dataloader_args(dataloader) | |||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||||
sampler = dataloader_args.batch_sampler | |||||
elif dataloader_args.sampler: | |||||
sampler = dataloader_args.sampler | |||||
else: | |||||
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | |||||
num_consumed_batches = states.pop('num_consumed_batches') | |||||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | |||||
sampler_states = sampler.state_dict() | |||||
# 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||||
# 会多于实际消耗的数量。 | |||||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||||
if num_consumed_samples_array is not None: | |||||
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||||
try: | |||||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||||
except TypeError: # 有可能 batch_size 为 None,就只有损失精度了 | |||||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||||
else: | |||||
raise RuntimeError( | |||||
'The sampler has no `state_dict()` method, so training cannot be resumed from the exact batch.') | |||||
# 2. 保存 optimizers 的状态; | |||||
# 2. 保存模型的状态; | |||||
if should_save_model: | |||||
self.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs) | |||||
if only_state_dict: | |||||
logger.debug("Save model state dict.") | |||||
else: | |||||
logger.debug("Save model.") | |||||
# 3. 保存 optimizers 的状态; | |||||
optimizers_state_dict = {} | optimizers_state_dict = {} | ||||
for i in range(len(self.optimizers)): | for i in range(len(self.optimizers)): | ||||
optimizer: Optimizer = self.optimizers[i] | optimizer: Optimizer = self.optimizers[i] | ||||
optimizer_state = optimizer.state_dict() | optimizer_state = optimizer.state_dict() | ||||
optimizer_state = {name: param.cpu().detach().clone() for name, param in optimizer_state.items()} | |||||
optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu") | |||||
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | ||||
states["optimizers_state_dict"] = optimizers_state_dict | |||||
paddle.save(states, folder) | |||||
def load(self, filepath) -> Dict: | |||||
r""" | |||||
断点重训的加载函数,注意该函数会负责读取数据,并且恢复模型和 optimizers 的 state_dict 等; | |||||
driver 实例需要在该函数中先加载模型和 optimizers 的 state_dict,然后将一个 state 字典返回给 trainer 。 | |||||
因此 save 函数和 load 函数的接受和返回值应该是对应的; | |||||
该函数需要在所有 rank 上执行。 | |||||
:param filepath: 保存断点重训的状态的文件名; | |||||
:return: 需要返回 save 函数输入的 states 内容; | |||||
""" | |||||
states = paddle.load(filepath) | |||||
logger.debug("Save optimizer state dict.") | |||||
states["optimizers_state_dict"] = optimizers_state_dict | |||||
paddle.save(states, str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) | |||||
def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||||
states = paddle.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) | |||||
# 1. 加载 optimizers 的状态; | # 1. 加载 optimizers 的状态; | ||||
optimizers_state_dict = states["optimizers_state_dict"] | optimizers_state_dict = states["optimizers_state_dict"] | ||||
for i in range(len(self.optimizers)): | for i in range(len(self.optimizers)): | ||||
optimizer: paddle.optimizer.Optimizer = self.optimizers[i] | |||||
optimizer: Optimizer = self.optimizers[i] | |||||
optimizer.set_state_dict(optimizers_state_dict[f"optimizer{i}"]) | optimizer.set_state_dict(optimizers_state_dict[f"optimizer{i}"]) | ||||
logger.debug("Load optimizer state dict.") | |||||
# 2. 加载模型状态; | # 2. 加载模型状态; | ||||
model = self.unwrap_model() | |||||
model.load_dict(states["model_state_dict"]) | |||||
if should_load_model: | |||||
self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict) | |||||
if only_state_dict: | |||||
logger.debug("Load model state dict.") | |||||
else: | |||||
logger.debug("Load model.") | |||||
# 3. 恢复 sampler 的状态; | |||||
dataloader_args = self.get_dataloader_args(dataloader) | |||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||||
sampler = dataloader_args.batch_sampler | |||||
elif isinstance(dataloader_args.sampler, ReproducibleSampler): | |||||
sampler = dataloader_args.sampler | |||||
elif self.is_distributed(): | |||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") | |||||
else: | |||||
sampler = RandomBatchSampler( | |||||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | |||||
batch_size=dataloader_args.batch_size, | |||||
drop_last=dataloader_args.drop_last | |||||
) | |||||
sampler.load_state_dict(states['sampler_states']) | |||||
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | |||||
# 4. 修改 trainer_state.batch_idx_in_epoch | |||||
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | |||||
if not isinstance(sampler, ReproducibleBatchSampler): | |||||
if dataloader_args.drop_last: | |||||
batch_idx_in_epoch = len( | |||||
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size | |||||
else: | |||||
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ | |||||
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size | |||||
# sampler 是 batch_sampler; | |||||
else: | |||||
batch_idx_in_epoch = sampler.batch_idx_in_epoch | |||||
states["batch_idx_in_epoch"] = batch_idx_in_epoch | |||||
self.barrier() | |||||
return states | return states | ||||
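The `batch_idx_in_epoch` arithmetic for the sampler case is ceiling division on both sides of a subtraction; a numeric check with made-up values:

```python
batch_size, sampler_len, num_left_samples = 8, 100, 36   # hypothetical numbers
drop_last = False

if drop_last:
    batch_idx_in_epoch = sampler_len // batch_size - num_left_samples // batch_size
else:
    batch_idx_in_epoch = ((sampler_len + batch_size - 1) // batch_size
                          - (num_left_samples + batch_size - 1) // batch_size)
print(batch_idx_in_epoch)  # 13 total batches - 5 still to come = 8 already consumed
```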
def get_evaluate_context(self): | def get_evaluate_context(self): | ||||
@@ -282,7 +381,7 @@ class PaddleDriver(Driver): | |||||
`randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_. | `randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_. | ||||
""" | """ | ||||
# implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 | # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 | ||||
global_rank = rank if rank is not None else rank_zero_call.rank | |||||
global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) | |||||
# TODO gpu | # TODO gpu | ||||
process_seed = paddle.fluid.core.default_cpu_generator().initial_seed() | process_seed = paddle.fluid.core.default_cpu_generator().initial_seed() | ||||
# back out the base seed so we can use all the bits | # back out the base seed so we can use all the bits | ||||
@@ -313,3 +412,64 @@ class PaddleDriver(Driver): | |||||
""" | """ | ||||
if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): | if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): | ||||
dataloader.batch_sampler.set_epoch(cur_epoch_idx) | dataloader.batch_sampler.set_epoch(cur_epoch_idx) | ||||
@staticmethod | |||||
def get_dataloader_args(dataloader: "DataLoader"): | |||||
""" | |||||
获取 dataloader 的 dataset、batch_sampler、sampler、batch_size、shuffle、drop_last 等属性; | |||||
""" | |||||
@dataclass | |||||
class Res: | |||||
dataset: Optional[Dataset] = None | |||||
batch_sampler: Optional[BatchSampler] = None | |||||
sampler: Optional[Sampler] = None | |||||
batch_size: Optional[int] = None | |||||
shuffle: Optional[bool] = None | |||||
drop_last: Optional[bool] = None | |||||
res = Res() | |||||
# paddle 的 DataLoader 一定会有 dataset 属性; | |||||
res.dataset = dataloader.dataset | |||||
if dataloader.batch_sampler is not None: | |||||
# 不过在 paddle 中,我们限定了 batch_sampler 不能为 None | |||||
res.batch_sampler = dataloader.batch_sampler | |||||
if hasattr(dataloader.batch_sampler, "batch_size"): | |||||
res.batch_size = getattr(dataloader.batch_sampler, "batch_size") | |||||
# 用户使用的是自己的 batch_sampler 并且其没有 "batch_size" 属性; | |||||
else: | |||||
dataloader_iter = iter(dataloader) | |||||
pre_sample = next(dataloader_iter) | |||||
res.batch_size = pre_sample.shape[0] | |||||
if hasattr(dataloader.batch_sampler, "sampler"): | |||||
res.sampler = dataloader.batch_sampler.sampler | |||||
if hasattr(dataloader.batch_sampler.sampler, "shuffle"): | |||||
res.shuffle = dataloader.batch_sampler.sampler.shuffle | |||||
elif isinstance(dataloader.batch_sampler.sampler, RandomSampler): | |||||
res.shuffle = True | |||||
else: | |||||
res.shuffle = False | |||||
# RandomBatchSampler 的情况 | |||||
elif hasattr(dataloader.batch_sampler, "batch_sampler"): | |||||
batch_sampler = dataloader.batch_sampler.batch_sampler | |||||
res.sampler = batch_sampler.sampler | |||||
if hasattr(batch_sampler.sampler, "shuffle"): | |||||
res.shuffle = batch_sampler.sampler.shuffle | |||||
elif isinstance(batch_sampler.sampler, RandomSampler): | |||||
res.shuffle = True | |||||
else: | |||||
res.shuffle = False | |||||
else: | |||||
res.sampler = None | |||||
res.shuffle = False | |||||
if hasattr(dataloader.batch_sampler, "drop_last"): | |||||
res.drop_last = getattr(dataloader.batch_sampler, "drop_last") | |||||
# 用户使用的是自己的 batch_sampler 并且其没有 "drop_last" 属性; | |||||
else: | |||||
res.drop_last = False | |||||
return res |
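Probing a default paddle DataLoader with the new helper; the fastNLP import path and the expected output (which relies on paddle's default BatchSampler wiring) are assumptions:

```python
import paddle
from paddle.io import DataLoader, TensorDataset
from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver  # assumed path

ds = TensorDataset([paddle.randn([10, 4])])
dl = DataLoader(ds, batch_size=4, shuffle=True, drop_last=False)

args = PaddleDriver.get_dataloader_args(dl)    # staticmethod defined above
print(args.batch_size, args.shuffle, args.drop_last)  # expected: 4 True False
```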
@@ -2,6 +2,7 @@ import os | |||||
from typing import Optional, Dict, Union | from typing import Optional, Dict, Union | ||||
from .paddle_driver import PaddleDriver | from .paddle_driver import PaddleDriver | ||||
from .utils import replace_batch_sampler, replace_sampler, get_device_from_visible | |||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES | from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES | ||||
from fastNLP.core.utils import ( | from fastNLP.core.utils import ( | ||||
@@ -10,7 +11,12 @@ from fastNLP.core.utils import ( | |||||
get_paddle_device_id, | get_paddle_device_id, | ||||
paddle_move_data_to_device, | paddle_move_data_to_device, | ||||
) | ) | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||||
from fastNLP.core.samplers import ( | |||||
ReproducibleBatchSampler, | |||||
RandomBatchSampler, | |||||
ReproducibleSampler, | |||||
re_instantiate_sampler, | |||||
) | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
@@ -22,16 +28,13 @@ __all__ = [ | |||||
] | ] | ||||
class PaddleSingleDriver(PaddleDriver): | class PaddleSingleDriver(PaddleDriver): | ||||
def __init__(self, model, device: Optional[str], fp16: Optional[bool] = False, **kwargs): | |||||
def __init__(self, model, device: str, fp16: Optional[bool] = False, **kwargs): | |||||
super(PaddleSingleDriver, self).__init__(model, fp16=fp16, **kwargs) | super(PaddleSingleDriver, self).__init__(model, fp16=fp16, **kwargs) | ||||
if device is None: | if device is None: | ||||
raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.") | raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.") | ||||
if isinstance(device, int): | |||||
self.model_device = get_paddle_gpu_str(device) | |||||
else: | |||||
self.model_device = device | |||||
self.model_device = get_paddle_gpu_str(device) | |||||
self.local_rank = 0 | self.local_rank = 0 | ||||
self.global_rank = 0 | self.global_rank = 0 | ||||
@@ -93,18 +96,18 @@ class PaddleSingleDriver(PaddleDriver): | |||||
self._test_signature_fn = model.forward | self._test_signature_fn = model.forward | ||||
def setup(self): | def setup(self): | ||||
user_visible_devices = os.environ[USER_CUDA_VISIBLE_DEVICES] | |||||
device_id = get_paddle_device_id(self.model_device) | |||||
if user_visible_devices is not None and user_visible_devices != "": | |||||
# 不为空,说明用户设置了 CUDA_VISIBLDE_DEVICES | |||||
device_id = user_visible_devices.split(",")[device_id] | |||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) | |||||
paddle.device.set_device("gpu:0") | |||||
self.model.to("gpu:0") | |||||
device = self.model_device | |||||
if device != "cpu": | |||||
device_id = get_paddle_device_id(device) | |||||
device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id] | |||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) | |||||
device = get_device_from_visible(device, output_type=str) | |||||
paddle.device.set_device(device) | |||||
self.model.to(device) | |||||
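The remapping above composes the user's visibility mask with the requested card; a standalone trace with made-up values:

```python
import os

# Suppose the process was launched with USER_CUDA_VISIBLE_DEVICES="3,5,7" and
# the user asked for "gpu:1": the driver pins the process to physical card 5.
user_visible = "3,5,7".split(",")
device_id = 1                                   # get_paddle_device_id("gpu:1")
os.environ["CUDA_VISIBLE_DEVICES"] = str(user_visible[device_id])
print(os.environ["CUDA_VISIBLE_DEVICES"])       # 5
```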
def train_step(self, batch) -> Dict: | def train_step(self, batch) -> Dict: | ||||
# 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; | # 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; | ||||
if isinstance(batch, Dict): | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | ||||
else: | else: | ||||
return self._train_step(batch) | return self._train_step(batch) | ||||
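The `wo_auto_param_call` switch added throughout this patch gates the dict-unpacking branch above. As a rough illustration of that matching idea, here is a simplified stand-in (not fastNLP's actual `auto_param_call`, which additionally honors `signature_fn` and parameter defaults):

```python
import inspect
from typing import Callable, Dict

def naive_auto_param_call(fn: Callable, batch: Dict):
    # Keep only the dict entries whose keys match fn's parameter names,
    # then call fn with that subset as keyword arguments.
    params = inspect.signature(fn).parameters
    return fn(**{name: batch[name] for name in params if name in batch})

def train_step(input_ids, labels):
    return {"loss": sum(input_ids) + labels}

batch = {"input_ids": [1, 2, 3], "labels": 10, "extra": None}
print(naive_auto_param_call(train_step, batch))  # {'loss': 16}
```

With `model_wo_auto_param_call=True` (read in `TorchDriver.__init__` later in this patch; the paddle driver is assumed to read the same kwarg), the raw batch is handed to `train_step` unchanged.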
@@ -118,13 +121,13 @@ class PaddleSingleDriver(PaddleDriver): | |||||
self.grad_scaler.update() | self.grad_scaler.update() | ||||
def validate_step(self, batch) -> Dict: | def validate_step(self, batch) -> Dict: | ||||
if isinstance(batch, Dict): | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | ||||
else: | else: | ||||
return self._validate_step(batch) | return self._validate_step(batch) | ||||
def test_step(self, batch) -> Dict: | def test_step(self, batch) -> Dict: | ||||
if isinstance(batch, Dict): | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | ||||
else: | else: | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
@@ -133,38 +136,40 @@ class PaddleSingleDriver(PaddleDriver): | |||||
r""" | r""" | ||||
将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。 | 将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。 | ||||
在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 | 在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 | ||||
在单卡时,由于 CUDA_VISIBLE_DEVICES 始终被限制在一个设备上,因此实际上只会迁移到 `gpu:0` | |||||
:return: 将移动到指定机器上的 batch 对象返回; | :return: 将移动到指定机器上的 batch 对象返回; | ||||
""" | """ | ||||
return paddle_move_data_to_device(batch, "gpu:0") | |||||
device = get_device_from_visible(self.data_device) | |||||
return paddle_move_data_to_device(batch, device) | |||||
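A hedged usage sketch of the move (assumes paddle is installed; per the docstring, the helper recurses through nested dict/list collections):

```python
import paddle

batch = {"input_ids": paddle.to_tensor([1, 2, 3]),
         "extras": [paddle.to_tensor([4.0])]}
moved = paddle_move_data_to_device(batch, "cpu")  # same nesting, tensors moved
print(moved["input_ids"].place)                   # a CPU place
```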
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, | |||||
reproducible: bool = False): | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], | |||||
reproducible: bool = False, sampler_or_batch_sampler=None): | |||||
# 暂时不支持IteratorDataset | |||||
# 暂时不支持 IterableDataset | |||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | assert dataloader.dataset_kind != _DatasetKind.ITER, \ | ||||
"FastNLP does not support `IteratorDataset` now." | |||||
"FastNLP does not support `IteratorDataset` now." | |||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||||
if isinstance(dist, ReproducibleBatchSampler): | if isinstance(dist, ReproducibleBatchSampler): | ||||
dataloader.batch_sampler = dist | |||||
return dataloader | |||||
if isinstance(dist, ReproducibleSampler): | |||||
dataloader.batch_sampler.sampler = dist | |||||
return dataloader | |||||
return replace_batch_sampler(dataloader, dist) | |||||
elif isinstance(dist, ReproducibleSampler): | |||||
return replace_sampler(dataloader, dist) | |||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初始化时调用; | |||||
args = self.get_dataloader_args(dataloader) | |||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
elif isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
return replace_sampler(dataloader, sampler) | |||||
if reproducible: | if reproducible: | ||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||||
return dataloader | |||||
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): | |||||
return dataloader | |||||
else: | |||||
# TODO | |||||
batch_sampler = ReproducibleBatchSampler( | |||||
batch_sampler=dataloader.batch_sampler, | |||||
batch_size=dataloader.batch_sampler.batch_size, | |||||
drop_last=dataloader.drop_last | |||||
) | |||||
dataloader.batch_sampler = batch_sampler | |||||
return dataloader | |||||
batch_sampler = RandomBatchSampler( | |||||
batch_sampler=args.batch_sampler, | |||||
batch_size=args.batch_size, | |||||
drop_last=args.drop_last | |||||
) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
else: | else: | ||||
return dataloader | return dataloader | ||||
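`re_instantiate_sampler` appears in both branches above. A hedged sketch of what it is assumed to do: rebuild a sampler of the same (or a given) class from its stored attributes, so a state-free copy can be injected into the new dataloader:

```python
def re_instantiate_sampler_sketch(sampler, new_sampler_class=None):
    # Assumes the sampler stores its __init__ kwargs 1:1 as attributes.
    cls = new_sampler_class or type(sampler)
    return cls(**vars(sampler))

class ToySampler:
    def __init__(self, dataset, shuffle=True):
        self.dataset, self.shuffle = dataset, shuffle

clone = re_instantiate_sampler_sketch(ToySampler([1, 2, 3], shuffle=False))
print(type(clone).__name__, clone.shuffle)  # ToySampler False
```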
@@ -4,12 +4,14 @@ import struct | |||||
import random | import random | ||||
import inspect | import inspect | ||||
import numpy as np | import numpy as np | ||||
from copy import deepcopy | |||||
from contextlib import ExitStack, closing | from contextlib import ExitStack, closing | ||||
from enum import IntEnum | from enum import IntEnum | ||||
from typing import Dict, Optional, Union | from typing import Dict, Optional, Union | ||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
from fastNLP.core.utils import get_paddle_device_id, auto_param_call | |||||
from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to | |||||
from fastNLP.core.samplers import RandomSampler | |||||
from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES | from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -18,7 +20,7 @@ if _NEED_IMPORT_PADDLE: | |||||
import paddle | import paddle | ||||
from paddle import nn | from paddle import nn | ||||
from paddle.nn import Layer | from paddle.nn import Layer | ||||
from paddle.io import DataLoader, BatchSampler | |||||
from paddle.io import DataLoader, BatchSampler, Dataset | |||||
from paddle.amp import auto_cast, GradScaler | from paddle.amp import auto_cast, GradScaler | ||||
else: | else: | ||||
from fastNLP.core.utils.dummy_class import DummyClass as Layer | from fastNLP.core.utils.dummy_class import DummyClass as Layer | ||||
@@ -85,7 +87,7 @@ class ForwardState(IntEnum): | |||||
TEST = 2 | TEST = 2 | ||||
PREDICT = 3 | PREDICT = 3 | ||||
_MODE_PARAMETER = "_forward_state" | |||||
_MODE_PARAMETER = "forward_state" | |||||
class _FleetWrappingModel(Layer): | class _FleetWrappingModel(Layer): | ||||
""" | """ | ||||
@@ -151,24 +153,25 @@ class _FleetWrappingModel(Layer): | |||||
def forward(self, batch, **kwargs) -> Dict: | def forward(self, batch, **kwargs) -> Dict: | ||||
_forward_state = kwargs.pop(_MODE_PARAMETER) | |||||
forward_state = kwargs.pop(_MODE_PARAMETER) | |||||
wo_auto_param_call = kwargs.pop("wo_auto_param_call") | |||||
if _forward_state == ForwardState.TRAIN: | |||||
if isinstance(batch, Dict): | |||||
if forward_state == ForwardState.TRAIN: | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | ||||
else: | else: | ||||
return self._train_step(batch) | return self._train_step(batch) | ||||
elif _forward_state == ForwardState.VALIDATE: | |||||
if isinstance(batch, Dict): | |||||
elif forward_state == ForwardState.VALIDATE: | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | ||||
else: | else: | ||||
return self._validate_step(batch) | return self._validate_step(batch) | ||||
elif _forward_state == ForwardState.TEST: | |||||
if isinstance(batch, Dict): | |||||
elif forward_state == ForwardState.TEST: | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | ||||
else: | else: | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
elif _forward_state == ForwardState.PREDICT: | |||||
elif forward_state == ForwardState.PREDICT: | |||||
raise NotImplementedError("'PREDICT' mode has not been implemented.") | raise NotImplementedError("'PREDICT' mode has not been implemented.") | ||||
else: | else: | ||||
raise NotImplementedError("You should direct a concrete mode.") | raise NotImplementedError("You should direct a concrete mode.") | ||||
@@ -205,7 +208,6 @@ class DummyGradScaler: | |||||
def state_dict(self): | def state_dict(self): | ||||
return {} | return {} | ||||
def _build_fp16_env(dummy=False): | def _build_fp16_env(dummy=False): | ||||
if dummy: | if dummy: | ||||
auto_cast = ExitStack | auto_cast = ExitStack | ||||
@@ -255,61 +257,77 @@ def get_host_name_ip(): | |||||
except: | except: | ||||
return None | return None | ||||
def get_device_from_visible(device: Union[str, int]): | |||||
def get_device_from_visible(device: Union[str, int], output_type=int): | |||||
""" | """ | ||||
在有 CUDA_VISIBLE_DEVICES 的情况下,获取对应的设备。 | 在有 CUDA_VISIBLE_DEVICES 的情况下,获取对应的设备。 | ||||
如 CUDA_VISIBLE_DEVICES=2,3 ,device=3 ,则返回1。 | 如 CUDA_VISIBLE_DEVICES=2,3 ,device=3 ,则返回1。 | ||||
:param devices:未转化的设备名 | |||||
:param device: 未转化的设备名 | |||||
:param output_type: 返回值的类型 | |||||
:return: 转化后的设备id | :return: 转化后的设备id | ||||
""" | """ | ||||
if output_type not in [int, str]: | |||||
raise ValueError("Parameter `output_type` should be one of these types: [int, str]") | |||||
if device == "cpu": | if device == "cpu": | ||||
return device | return device | ||||
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") | cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") | ||||
idx = get_paddle_device_id(device) | idx = get_paddle_device_id(device) | ||||
if cuda_visible_devices is None or cuda_visible_devices == "": | if cuda_visible_devices is None or cuda_visible_devices == "": | ||||
# 这个判断一般不会发生,因为 fastnlp 会为 paddle 强行注入 CUDA_VISIBLE_DEVICES | # 这个判断一般不会发生,因为 fastnlp 会为 paddle 强行注入 CUDA_VISIBLE_DEVICES | ||||
return idx | |||||
raise RuntimeError("This situation should not happen, please report us this bug.") | |||||
else: | else: | ||||
# 利用 USER_CUDA_VISIBLE_DEVICES 获取用户期望的设备 | # 利用 USER_CUDA_VISIBLE_DEVICES 获取用户期望的设备 | ||||
user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) | user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) | ||||
if user_visible_devices is not None and user_visible_devices != "": | |||||
# 不为空,说明用户设置了 CUDA_VISIBLDE_DEVICES | |||||
idx = user_visible_devices.split(",")[idx] | |||||
else: | |||||
idx = str(idx) | |||||
if user_visible_devices is None: | |||||
raise RuntimeError("This situation cannot happen, please report a bug to us.") | |||||
idx = user_visible_devices.split(",")[idx] | |||||
cuda_visible_devices_list = cuda_visible_devices.split(',') | cuda_visible_devices_list = cuda_visible_devices.split(',') | ||||
assert idx in cuda_visible_devices_list, "Can't find "\ | |||||
"your devices %s in CUDA_VISIBLE_DEVICES[%s]."\ | |||||
% (idx, cuda_visible_devices) | |||||
if idx not in cuda_visible_devices_list: | |||||
raise ValueError(f"Can't find your devices {idx} in CUDA_VISIBLE_DEVICES[{cuda_visible_devices}].") | |||||
res = cuda_visible_devices_list.index(idx) | res = cuda_visible_devices_list.index(idx) | ||||
return res | |||||
if output_type == int: | |||||
return res | |||||
else: | |||||
return f"gpu:{res}" | |||||
def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"): | |||||
# 拿到实例属性; | |||||
def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler"): | |||||
""" | |||||
利用 `batch_sampler` 重新构建一个 DataLoader,起到替换 `batch_sampler` 又不影响原 `dataloader` 的作用。 | |||||
考虑了用户自己定制了 DataLoader 的情形。 | |||||
""" | |||||
# 拿到非下划线开头的实例属性; | |||||
instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')} | instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')} | ||||
# 拿到 dataloader '__init__' 函数的默认函数签名; | |||||
# 拿到 dataloader '__init__' 函数的默认函数签名;可以获取参数名和参数的默认值以及类型 | |||||
init_params = dict(inspect.signature(dataloader.__init__).parameters) | init_params = dict(inspect.signature(dataloader.__init__).parameters) | ||||
# 这里为什么要单独弄的原因在于,用户在定制自己的 dataloader 的同时可能为了方便只设定一些参数,而后面直接使用 **kwargs 的方式,这时如果 | # 这里为什么要单独弄的原因在于,用户在定制自己的 dataloader 的同时可能为了方便只设定一些参数,而后面直接使用 **kwargs 的方式,这时如果 | ||||
# 其在初始化自己的 dataloader 实例的时候加入了一些其它的新的参数(首先这一步是必要的,因为我们只能通过这样加 sampler;另一方面,用户 | # 其在初始化自己的 dataloader 实例的时候加入了一些其它的新的参数(首先这一步是必要的,因为我们只能通过这样加 sampler;另一方面,用户 | ||||
# 可能确实通过 **kwargs 加入了一些新的参数),如果假设用户是这样使用的: "super().__init__(**kwargs)",那么我们就只能去 DataLoader | # 可能确实通过 **kwargs 加入了一些新的参数),如果假设用户是这样使用的: "super().__init__(**kwargs)",那么我们就只能去 DataLoader | ||||
# 中寻找; | |||||
# 中寻找;VAR_KEYWORD 代表 **kwargs | |||||
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | ||||
if has_variadic_kwargs: | if has_variadic_kwargs: | ||||
init_params.update(dict(inspect.signature(DataLoader.__init__).parameters)) | init_params.update(dict(inspect.signature(DataLoader.__init__).parameters)) | ||||
del init_params["self"] | del init_params["self"] | ||||
# 因为我们刚才可能用 DataLoader 的默认参数将用户定制的 dataloader 的参数覆盖掉了,因此需要重新弄一遍; | # 因为我们刚才可能用 DataLoader 的默认参数将用户定制的 dataloader 的参数覆盖掉了,因此需要重新弄一遍; | ||||
# 将同时在实例属性和参数名中出现且不是默认值的参数收集起来 | |||||
non_default_params = {name for name, p in init_params.items() if | non_default_params = {name for name, p in init_params.items() if | ||||
name in instance_attrs and p.default != instance_attrs[name]} | name in instance_attrs and p.default != instance_attrs[name]} | ||||
# add `dataset` as it might have been replaced with `*args` | # add `dataset` as it might have been replaced with `*args` | ||||
non_default_params.add("dataset") | non_default_params.add("dataset") | ||||
# 收集不是默认值的参数和它的值 | |||||
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} | reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} | ||||
reconstruct_args.update({"batch_sampler": sampler, "shuffle": False, "drop_last": False, "batch_size": 1}) | |||||
# persistent_workers 在类中的对应成员带有下划线,因此添加进来 | |||||
reconstruct_args.update({ | |||||
"batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1, | |||||
"persistent_workers": dataloader._persistent_workers, | |||||
}) | |||||
# POSITIONAL_OR_KEYWORD 代表一般的参数 | |||||
# 收集初始化函数中出现的、一般形式的、不带默认值且不在 reconstruct_args 中的参数 | |||||
# 也即它们没有在初始化函数和实例成员中同时出现 | |||||
required_args = { | required_args = { | ||||
p.name | p.name | ||||
for p in init_params.values() | for p in init_params.values() | ||||
@@ -323,12 +341,9 @@ def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"): | |||||
required_args = sorted(required_args) | required_args = sorted(required_args) | ||||
dataloader_self_name = dataloader.__class__.__name__ | dataloader_self_name = dataloader.__class__.__name__ | ||||
raise Exception( | raise Exception( | ||||
f"Trying to inject `DistributedBatchSampler` into the `{dataloader_self_name}` instance. " | |||||
f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. " | |||||
"This would fail as some of the `__init__` arguments are not available as instance attributes. " | "This would fail as some of the `__init__` arguments are not available as instance attributes. " | ||||
f"The missing attributes are {required_args}. " | f"The missing attributes are {required_args}. " | ||||
f"HINT: If you wrote the `{dataloader_self_name}` class, define `self.missing_arg_name` or " | |||||
"manually add the `DistributedBatchSampler` as: " | |||||
f"`{dataloader_self_name}(dataset, sampler=DistributedBatchSampler(dataset))`." | |||||
) | ) | ||||
# 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs; | # 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs; | ||||
@@ -340,12 +355,28 @@ def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"): | |||||
missing_kwargs = sorted(missing_kwargs) | missing_kwargs = sorted(missing_kwargs) | ||||
dataloader_self_name = dataloader.__class__.__name__ | dataloader_self_name = dataloader.__class__.__name__ | ||||
raise Exception( | raise Exception( | ||||
f"Trying to inject `DistributedBatchSampler` into the `{dataloader_self_name}` instance. " | |||||
f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. " | |||||
"This would fail as it doesn't expose all its attributes in the `__init__` signature. " | "This would fail as it doesn't expose all its attributes in the `__init__` signature. " | ||||
f"The missing arguments are {missing_kwargs}. " | f"The missing arguments are {missing_kwargs}. " | ||||
f"HINT: If you wrote the `{dataloader_self_name}` class, add the `__init__` arguments or " | |||||
"manually add the `DistributedBatchSampler` as: " | |||||
f"`{dataloader_self_name}(dataset, sampler=DistributedBatchSampler(dataset))`." | |||||
) | ) | ||||
return type(dataloader)(**reconstruct_args) | return type(dataloader)(**reconstruct_args) | ||||
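The reconstruction trick above, condensed into a runnable sketch: collect the `__init__` parameters that are also stored as instance attributes with non-default values, override what needs to change, and re-instantiate the (possibly user-subclassed) class:

```python
import inspect

class Loader:
    def __init__(self, dataset, batch_size=1, shuffle=False):
        self.dataset, self.batch_size, self.shuffle = dataset, batch_size, shuffle

def rebuild(obj, **overrides):
    params = inspect.signature(type(obj).__init__).parameters
    attrs = {k: v for k, v in vars(obj).items() if not k.startswith("_")}
    # keep only arguments whose stored value differs from the declared default
    kwargs = {n: attrs[n] for n in params
              if n in attrs and params[n].default != attrs[n]}
    kwargs.update(overrides)
    return type(obj)(**kwargs)

dl = rebuild(Loader(dataset=[1, 2, 3], batch_size=4), batch_size=8)
print(dl.batch_size, dl.dataset)  # 8 [1, 2, 3]
```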
def replace_sampler(dataloader, new_sampler): | |||||
""" | |||||
使用 `new_sampler` 重新构建一个 BatchSampler,并替换到 `dataloader` 中 | |||||
""" | |||||
new_batch_sampler = deepcopy(dataloader.batch_sampler) | |||||
new_batch_sampler.sampler = new_sampler | |||||
return replace_batch_sampler(dataloader, new_batch_sampler) | |||||
def optimizer_state_to_device(state, device): | |||||
new_state = {} | |||||
for name, param in state.items(): | |||||
if isinstance(param, dict): | |||||
new_state[name] = optimizer_state_to_device(param, device) | |||||
elif isinstance(param, paddle.Tensor): | |||||
new_state[name] = paddle_to(param, device).clone() | |||||
else: | |||||
new_state[name] = param | |||||
return new_state |
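`optimizer_state_to_device` walks the state dict recursively, moving every `paddle.Tensor` onto `device` and leaving other entries untouched. A hedged usage sketch (assumes paddle is available), e.g. to make a checkpoint safe to serialize:

```python
import paddle

model = paddle.nn.Linear(4, 4)
opt = paddle.optimizer.Adam(parameters=model.parameters())
model(paddle.randn([2, 4])).sum().backward()
opt.step()  # populate the per-parameter state tensors

cpu_state = optimizer_state_to_device(opt.state_dict(), "cpu")
```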
@@ -12,6 +12,7 @@ if _NEED_IMPORT_TORCH: | |||||
import torch | import torch | ||||
import torch.distributed as dist | import torch.distributed as dist | ||||
from torch.nn.parallel import DistributedDataParallel | from torch.nn.parallel import DistributedDataParallel | ||||
from torch.utils.data import BatchSampler | |||||
__all__ = [ | __all__ = [ | ||||
'TorchDDPDriver' | 'TorchDDPDriver' | ||||
@@ -167,6 +168,7 @@ class TorchDDPDriver(TorchDriver): | |||||
不管是什么情况,`TorchDDPDriver` 在 `setup` 函数的最后,都会将所有进程的 pid 主动记录下来,这样当一个进程出现 exception 后, | 不管是什么情况,`TorchDDPDriver` 在 `setup` 函数的最后,都会将所有进程的 pid 主动记录下来,这样当一个进程出现 exception 后, | ||||
driver 的 on_exception 函数就会被 trainer 调用,其会调用 os.kill 指令将其它进程 kill 掉; | driver 的 on_exception 函数就会被 trainer 调用,其会调用 os.kill 指令将其它进程 kill 掉; | ||||
""" | """ | ||||
# 在加入很多东西后,需要注意这里调用 super 函数的位置; | |||||
super(TorchDDPDriver, self).__init__(model, fp16=fp16, **kwargs) | super(TorchDDPDriver, self).__init__(model, fp16=fp16, **kwargs) | ||||
if isinstance(model, torch.nn.DataParallel): | if isinstance(model, torch.nn.DataParallel): | ||||
@@ -202,8 +204,8 @@ class TorchDDPDriver(TorchDriver): | |||||
# 我们就直接将 model_device 置为 None; | # 我们就直接将 model_device 置为 None; | ||||
self.model_device = None | self.model_device = None | ||||
def _running_fn_(batch, step_fn, signature_fn): | |||||
if isinstance(batch, Dict): | |||||
def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call): | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(step_fn, batch, signature_fn=signature_fn) | return auto_param_call(step_fn, batch, signature_fn=signature_fn) | ||||
else: | else: | ||||
return step_fn(batch) | return step_fn(batch) | ||||
@@ -214,7 +216,7 @@ class TorchDDPDriver(TorchDriver): | |||||
"Notice your model is a `DistributedDataParallel` model. And your " | "Notice your model is a `DistributedDataParallel` model. And your " | ||||
"model also implements the `train_step` method, which we can not call actually, we will" | "model also implements the `train_step` method, which we can not call actually, we will" | ||||
" call `forward` function instead of `train_step` and you should note that.") | " call `forward` function instead of `train_step` and you should note that.") | ||||
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||||
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
# self._train_signature_fn = model.forward | # self._train_signature_fn = model.forward | ||||
if hasattr(model, "validate_step"): | if hasattr(model, "validate_step"): | ||||
@@ -222,7 +224,7 @@ class TorchDDPDriver(TorchDriver): | |||||
"Notice your model is a `DistributedDataParallel` model. And your " | "Notice your model is a `DistributedDataParallel` model. And your " | ||||
"model also implements the `validate_step` method, which we can not call actually, " | "model also implements the `validate_step` method, which we can not call actually, " | ||||
"we will call `forward` function instead of `validate_step` and you should note that.") | "we will call `forward` function instead of `validate_step` and you should note that.") | ||||
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||||
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
# self._validate_signature_fn = model.forward | # self._validate_signature_fn = model.forward | ||||
if hasattr(model, "test_step"): | if hasattr(model, "test_step"): | ||||
@@ -230,14 +232,11 @@ class TorchDDPDriver(TorchDriver): | |||||
"Notice your model is a `DistributedDataParallel` model. And your " | "Notice your model is a `DistributedDataParallel` model. And your " | ||||
"model also implements the `test_step` method, which we can not call actually, we will" | "model also implements the `test_step` method, which we can not call actually, we will" | ||||
" call `forward` function instead of `test_step` and you should note that.") | " call `forward` function instead of `test_step` and you should note that.") | ||||
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||||
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
# self._test_signature_fn = model.forward | # self._test_signature_fn = model.forward | ||||
# 当用户自己在外面初始化 DDP 时我们会将 model_device 置为 None,这时用户可以通过 `data_device` 将对应的数据移到指定的机器上; | # 当用户自己在外面初始化 DDP 时我们会将 model_device 置为 None,这时用户可以通过 `data_device` 将对应的数据移到指定的机器上; | ||||
self._data_device = kwargs.get("data_device", None) | self._data_device = kwargs.get("data_device", None) | ||||
# if self.outside_ddp and self._data_device is None: | |||||
# raise RuntimeError("When you initialize your ddp out of our control, the parameter " | |||||
# "`data_device` can not be None.") | |||||
if isinstance(self._data_device, int): | if isinstance(self._data_device, int): | ||||
if self._data_device < 0: | if self._data_device < 0: | ||||
raise ValueError("Parameter `data_device` can not be smaller than 0.") | raise ValueError("Parameter `data_device` can not be smaller than 0.") | ||||
@@ -349,9 +348,9 @@ class TorchDDPDriver(TorchDriver): | |||||
**self._ddp_kwargs | **self._ddp_kwargs | ||||
) | ) | ||||
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}) | |||||
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}) | |||||
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}) | |||||
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call) | |||||
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call) | |||||
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call) | |||||
self._configured = True | self._configured = True | ||||
@@ -472,12 +471,11 @@ class TorchDDPDriver(TorchDriver): | |||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " | raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " | ||||
"control.") | "control.") | ||||
else: | else: | ||||
if isinstance(dist, ReproducibleBatchSampler): | |||||
dist = re_instantiate_sampler(dist) | |||||
return replace_batch_sampler(dataloader, dist) | |||||
if isinstance(dist, ReproducibleSampler): | |||||
dist = re_instantiate_sampler(dist) | |||||
return replace_sampler(dataloader, dist) | |||||
args = self.get_dataloader_args(dataloader) | |||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||||
return replace_batch_sampler(dataloader, re_instantiate_sampler(args.batch_sampler)) | |||||
if isinstance(args.sampler, ReproducibleSampler): | |||||
return replace_sampler(dataloader, re_instantiate_sampler(args.sampler)) | |||||
return dataloader | return dataloader | ||||
# trainer | # trainer | ||||
elif dist == "dist": | elif dist == "dist": | ||||
@@ -526,18 +524,11 @@ class TorchDDPDriver(TorchDriver): | |||||
num_replicas=self.world_size, | num_replicas=self.world_size, | ||||
rank=self.global_rank | rank=self.global_rank | ||||
) | ) | ||||
return replace_sampler(dataloader, sampler) | |||||
batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
else: | else: | ||||
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | ||||
def backward(self, loss): | |||||
self.grad_scaler.scale(loss).backward() | |||||
def step(self): | |||||
for optimizer in self.optimizers: | |||||
self.grad_scaler.step(optimizer) | |||||
self.grad_scaler.update() | |||||
def is_global_zero(self): | def is_global_zero(self): | ||||
return self.global_rank == 0 | return self.global_rank == 0 | ||||
@@ -3,28 +3,20 @@ import pickle | |||||
_pickler = pickle.Pickler | _pickler = pickle.Pickler | ||||
_unpickler = pickle.Unpickler | _unpickler = pickle.Unpickler | ||||
from typing import Any, List | from typing import Any, List | ||||
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 | |||||
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 | |||||
from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
from torch import distributed as dist | from torch import distributed as dist | ||||
try: | |||||
from torch._C._distributed_c10d import ProcessGroupMPI | |||||
except ImportError: | |||||
_MPI_AVAILABLE = False | |||||
try: | |||||
from torch._C._distributed_c10d import ProcessGroupNCCL | |||||
except ImportError: | |||||
_NCCL_AVAILABLE = False | |||||
try: | |||||
from torch._C._distributed_c10d import ProcessGroupGloo | |||||
from torch._C._distributed_c10d import _ProcessGroupWrapper | |||||
except ImportError: | |||||
_GLOO_AVAILABLE = False | |||||
if _TORCH_GREATER_EQUAL_1_8: | |||||
try: | |||||
from torch._C._distributed_c10d import ProcessGroupGloo | |||||
from torch._C._distributed_c10d import _ProcessGroupWrapper | |||||
except ImportError: | |||||
pass | |||||
from fastNLP.core.utils import apply_to_collection | from fastNLP.core.utils import apply_to_collection | ||||
@@ -42,7 +34,7 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): | |||||
) | ) | ||||
def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None): | |||||
def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP): | |||||
""" | """ | ||||
从其它 rank gather 东西到 dst rank 。 | 从其它 rank gather 东西到 dst rank 。 | ||||
@@ -91,6 +83,9 @@ def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None) | |||||
>>> output | >>> output | ||||
['foo', 12, {1: 2}] | ['foo', 12, {1: 2}] | ||||
""" | """ | ||||
if group is None: | |||||
group = DEFAULT_TORCH_GROUP | |||||
if dist.distributed_c10d._rank_not_in_group(group): | if dist.distributed_c10d._rank_not_in_group(group): | ||||
return | return | ||||
@@ -193,7 +188,7 @@ def _to_device(tensor, device): | |||||
return tensor.contiguous().to(device) | return tensor.contiguous().to(device) | ||||
def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: | |||||
def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) ->List: | |||||
""" | """ | ||||
实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 | 实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 | ||||
@@ -217,7 +212,8 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: | |||||
:param group: | :param group: | ||||
:return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 | :return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 | ||||
""" | """ | ||||
# # 首先将所有的都移动到cpu上并且连续,防止有 pickle 出问题 | |||||
if group is None: | |||||
group = DEFAULT_TORCH_GROUP | |||||
if isinstance(obj, torch.Tensor): | if isinstance(obj, torch.Tensor): | ||||
objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))] | objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))] | ||||
dist.all_gather(objs, obj, group=group) | dist.all_gather(objs, obj, group=group) | ||||
@@ -232,7 +228,7 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: | |||||
return objs | return objs | ||||
def fastnlp_torch_broadcast_object(obj, src, device=None, group=None): | |||||
def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GROUP): | |||||
""" | """ | ||||
将 src 上的 obj 对象广播到其它 rank 上。 | 将 src 上的 obj 对象广播到其它 rank 上。 | ||||
@@ -242,6 +238,8 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=None): | |||||
:param group: | :param group: | ||||
:return: | :return: | ||||
""" | """ | ||||
if group is None: | |||||
group = DEFAULT_TORCH_GROUP | |||||
cur_rank = dist.get_rank(group) | cur_rank = dist.get_rank(group) | ||||
if cur_rank == src: | if cur_rank == src: | ||||
# 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里 | # 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里 | ||||
@@ -335,19 +333,21 @@ def all_gather_object(object_list, obj, group=None): | |||||
>>> output | >>> output | ||||
['foo', 12, {1: 2}] | ['foo', 12, {1: 2}] | ||||
""" | """ | ||||
if dist._rank_not_in_group(group): | |||||
if dist.distributed_c10d._rank_not_in_group(group): | |||||
return | return | ||||
if _TORCH_GREATER_EQUAL_1_8: | |||||
current_device = torch.device("cpu") | |||||
is_nccl_backend = _check_for_nccl_backend(group) | |||||
if is_nccl_backend: | |||||
# See note about using torch.cuda.current_device() here in docstring. | |||||
# We cannot simply use my_rank since rank == device is not necessarily | |||||
# true. | |||||
current_device = torch.device("cuda", torch.cuda.current_device()) | |||||
else: | |||||
current_device = torch.cuda.current_device() | |||||
input_tensor, local_size = _object_to_tensor(obj, device=current_device) | |||||
input_tensor, local_size = _object_to_tensor(obj) | |||||
current_device = torch.device("cpu") | |||||
is_nccl_backend = _check_for_nccl_backend(group) | |||||
if is_nccl_backend: | |||||
# See note about using torch.cuda.current_device() here in docstring. | |||||
# We cannot simply use my_rank since rank == device is not necessarily | |||||
# true. | |||||
current_device = torch.device("cuda", torch.cuda.current_device()) | |||||
input_tensor = input_tensor.to(current_device) | |||||
local_size = local_size.to(current_device) | |||||
# Gather all local sizes. This is so that we can find the max size, and index | # Gather all local sizes. This is so that we can find the max size, and index | ||||
# until the correct size when deserializing the tensors. | # until the correct size when deserializing the tensors. | ||||
group_size = dist.get_world_size(group=group) | group_size = dist.get_world_size(group=group) | ||||
@@ -378,3 +378,4 @@ def all_gather_object(object_list, obj, group=None): | |||||
tensor = tensor.cpu() | tensor = tensor.cpu() | ||||
tensor_size = object_size_list[i] | tensor_size = object_size_list[i] | ||||
object_list[i] = _tensor_to_object(tensor, tensor_size) | object_list[i] = _tensor_to_object(tensor, tensor_size) | ||||
return object_list |
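The patched `all_gather_object` keeps the upstream `torch.distributed` contract while fixing the device handling above (and finally returning the list). A runnable single-process sketch of that contract, using the stock API for illustration:

```python
import os
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

output = [None] * dist.get_world_size()
dist.all_gather_object(output, {"rank": dist.get_rank(), "payload": [1, 2]})
print(output)  # [{'rank': 0, 'payload': [1, 2]}]
dist.destroy_process_group()
```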
@@ -13,7 +13,7 @@ __all__ = [ | |||||
from .torch_driver import TorchDriver | from .torch_driver import TorchDriver | ||||
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | ||||
from fastNLP.core.utils import auto_param_call | from fastNLP.core.utils import auto_param_call | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -102,29 +102,21 @@ class TorchSingleDriver(TorchDriver): | |||||
def train_step(self, batch) -> Dict: | def train_step(self, batch) -> Dict: | ||||
# 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; | # 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; | ||||
if isinstance(batch, Dict): | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | ||||
else: | else: | ||||
return self._train_step(batch) | return self._train_step(batch) | ||||
def backward(self, loss): | |||||
self.grad_scaler.scale(loss).backward() | |||||
def step(self): | |||||
for optimizer in self.optimizers: | |||||
self.grad_scaler.step(optimizer) | |||||
self.grad_scaler.update() | |||||
def validate_step(self, batch) -> Dict: | def validate_step(self, batch) -> Dict: | ||||
# 因为我们 Tester 的逻辑就是将所有的 metric 传给 tester,然后 tester 控制具体 metric 的 update 和 compute;因此不管用户是否 | # 因为我们 Tester 的逻辑就是将所有的 metric 传给 tester,然后 tester 控制具体 metric 的 update 和 compute;因此不管用户是否 | ||||
# 实现 validate_step 函数,其都应该返回一个字典,具体使用哪些东西则是在 validate_batch_loop 中每一个具体的 metric 自己去拿的; | # 实现 validate_step 函数,其都应该返回一个字典,具体使用哪些东西则是在 validate_batch_loop 中每一个具体的 metric 自己去拿的; | ||||
if isinstance(batch, Dict): | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | ||||
else: | else: | ||||
return self._validate_step(batch) | return self._validate_step(batch) | ||||
def test_step(self, batch) -> Dict: | def test_step(self, batch) -> Dict: | ||||
if isinstance(batch, Dict): | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | ||||
else: | else: | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
@@ -148,7 +140,7 @@ class TorchSingleDriver(TorchDriver): | |||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
if reproducible: | if reproducible: | ||||
batch_sampler = ReproducibleBatchSampler( | |||||
batch_sampler = RandomBatchSampler( | |||||
batch_sampler=args.batch_sampler, | batch_sampler=args.batch_sampler, | ||||
batch_size=args.batch_size, | batch_size=args.batch_size, | ||||
drop_last=args.drop_last | drop_last=args.drop_last | ||||
@@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | |||||
from fastNLP.envs import rank_zero_call | from fastNLP.envs import rank_zero_call | ||||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler | |||||
class TorchDriver(Driver): | class TorchDriver(Driver): | ||||
@@ -51,6 +51,9 @@ class TorchDriver(Driver): | |||||
# 用来设置 `torch_move_data_to_device` 中的 `non_blocking` 参数; | # 用来设置 `torch_move_data_to_device` 中的 `non_blocking` 参数; | ||||
self.non_blocking = kwargs.get("torch_non_blocking", True) | self.non_blocking = kwargs.get("torch_non_blocking", True) | ||||
# 用来设置是否关闭 auto_param_call 中的参数匹配功能; | |||||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | |||||
def zero_grad(self, set_to_none: bool = False): | def zero_grad(self, set_to_none: bool = False): | ||||
for optimizer in self.optimizers: | for optimizer in self.optimizers: | ||||
self._clear_grad(optimizer, set_to_none) | self._clear_grad(optimizer, set_to_none) | ||||
@@ -69,6 +72,14 @@ class TorchDriver(Driver): | |||||
p.grad.requires_grad_(False) | p.grad.requires_grad_(False) | ||||
p.grad.zero_() | p.grad.zero_() | ||||
def backward(self, loss): | |||||
self.grad_scaler.scale(loss).backward() | |||||
def step(self): | |||||
for optimizer in self.optimizers: | |||||
self.grad_scaler.step(optimizer) | |||||
self.grad_scaler.update() | |||||
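`backward`/`step` now live once on the base `TorchDriver` (the copies in the single-card and DDP drivers are removed elsewhere in this patch). They implement the standard `torch.cuda.amp` pattern, shown standalone below; it is CPU-safe because `GradScaler(enabled=False)` degrades to plain no-op scaling:

```python
import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=False)

loss = model(torch.randn(2, 4)).sum()
scaler.scale(loss).backward()  # backward on the (possibly) scaled loss
scaler.step(opt)               # unscales gradients, then optimizer.step()
scaler.update()                # adjust the scale factor for the next step
```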
@staticmethod | @staticmethod | ||||
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | ||||
if is_train: | if is_train: | ||||
@@ -102,7 +113,7 @@ class TorchDriver(Driver): | |||||
if mode == "validate": | if mode == "validate": | ||||
if not hasattr(model, "validate_step"): | if not hasattr(model, "validate_step"): | ||||
if hasattr(model, "test_step"): | if hasattr(model, "test_step"): | ||||
logger.warning( | |||||
logger.warning_once( | |||||
"Your model does not have 'validate_step' method but has 'test_step' method, but you" | "Your model does not have 'validate_step' method but has 'test_step' method, but you" | ||||
"are using 'mode=validate', we are going to use 'test_step' to substitute for" | "are using 'mode=validate', we are going to use 'test_step' to substitute for" | ||||
"'validate_step'.") | "'validate_step'.") | ||||
@@ -191,9 +202,20 @@ class TorchDriver(Driver): | |||||
sampler = dataloader_args.sampler | sampler = dataloader_args.sampler | ||||
else: | else: | ||||
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | ||||
num_consumed_batches = states.pop('num_consumed_batches') | |||||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | ||||
states['sampler_states'] = sampler.state_dict() | |||||
sampler_states = sampler.state_dict() | |||||
# 如果有,需要针对 num_consumed_samples 做特殊的处理。因为 DataLoader 存在预取行为,直接使用 sampler 中的 num_consumed_samples | |||||
# 会造成多于实际消耗的问题。 | |||||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||||
if num_consumed_samples_array is not None: | |||||
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||||
try: | |||||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||||
except TypeError:  # 有可能 batch_size 为 None,就只能损失精度了 | |||||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||||
else: | else: | ||||
raise RuntimeError( | raise RuntimeError( | ||||
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | ||||
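Why the array is needed: DataLoader workers prefetch, so the sampler's live counter runs ahead of what the Trainer has actually forwarded. Recording the counter at every yield lets `save` index by the number of batches really consumed. A toy illustration, with a plain list standing in for `NumConsumedSamplesArray` (which additionally keeps only a bounded window):

```python
# Counter before any batch, then after each yielded batch of size 16.
counter_at_batch = [0, 16, 32, 48, 64, 80]

num_consumed_batches = 3  # what the Trainer really ran, not what was prefetched
print(counter_at_batch[num_consumed_batches])  # 48, though the sampler is at 80
```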
@@ -252,7 +274,7 @@ class TorchDriver(Driver): | |||||
elif self.is_distributed(): | elif self.is_distributed(): | ||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") | raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") | ||||
else: | else: | ||||
sampler = ReproducibleBatchSampler( | |||||
sampler = RandomBatchSampler( | |||||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | ||||
batch_size=dataloader_args.batch_size, | batch_size=dataloader_args.batch_size, | ||||
drop_last=dataloader_args.drop_last | drop_last=dataloader_args.drop_last | ||||
@@ -8,6 +8,7 @@ import numpy as np | |||||
import inspect | import inspect | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
from fastNLP.core.samplers import re_instantiate_sampler | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
@@ -140,24 +141,25 @@ class _DDPWrappingModel(Module): | |||||
pytorch lightning 实现了先 unwrapping_model 的操作,但是感觉对于我们来说没有什么必须要,先写个注释放这里,之后有需求了再看; | pytorch lightning 实现了先 unwrapping_model 的操作,但是感觉对于我们来说没有什么必须要,先写个注释放这里,之后有需求了再看; | ||||
""" | """ | ||||
_forward_state = kwargs.pop(_MODE_PARAMETER) | |||||
forward_state = kwargs.pop(_MODE_PARAMETER) | |||||
wo_auto_param_call = kwargs.pop("wo_auto_param_call") | |||||
if _forward_state == ForwardState.TRAIN: | |||||
if isinstance(batch, Dict): | |||||
if forward_state == ForwardState.TRAIN: | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | ||||
else: | else: | ||||
return self._train_step(batch) | return self._train_step(batch) | ||||
elif _forward_state == ForwardState.VALIDATE: | |||||
if isinstance(batch, Dict): | |||||
elif forward_state == ForwardState.VALIDATE: | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | ||||
else: | else: | ||||
return self._validate_step(batch) | return self._validate_step(batch) | ||||
elif _forward_state == ForwardState.TEST: | |||||
if isinstance(batch, Dict): | |||||
elif forward_state == ForwardState.TEST: | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | ||||
else: | else: | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
elif _forward_state == ForwardState.PREDICT: | |||||
elif forward_state == ForwardState.PREDICT: | |||||
raise NotImplementedError("'PREDICT' mode has not been implemented.") | raise NotImplementedError("'PREDICT' mode has not been implemented.") | ||||
else: | else: | ||||
raise NotImplementedError("You should direct a concrete mode.") | raise NotImplementedError("You should direct a concrete mode.") | ||||
@@ -294,7 +296,6 @@ def replace_sampler(dataloader: "DataLoader", sampler): | |||||
"manually add the `DistributedSampler` as: " | "manually add the `DistributedSampler` as: " | ||||
f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`." | f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`." | ||||
) | ) | ||||
return type(dataloader)(**reconstruct_args) | return type(dataloader)(**reconstruct_args) | ||||
@@ -306,12 +307,8 @@ def _dataloader_init_kwargs_resolve_sampler( | |||||
""" | """ | ||||
batch_sampler = getattr(dataloader, "batch_sampler") | batch_sampler = getattr(dataloader, "batch_sampler") | ||||
# checking the batch sampler type is different than PyTorch default. | # checking the batch sampler type is different than PyTorch default. | ||||
if batch_sampler is not None and type(batch_sampler) is not BatchSampler: | |||||
batch_sampler = type(batch_sampler)( | |||||
sampler, | |||||
batch_size=batch_sampler.batch_size, | |||||
drop_last=batch_sampler.drop_last, | |||||
) | |||||
if batch_sampler is not None and not isinstance(batch_sampler, BatchSampler): | |||||
batch_sampler = re_instantiate_sampler(batch_sampler) | |||||
return { | return { | ||||
"sampler": None, | "sampler": None, | ||||
@@ -342,6 +339,9 @@ def replace_batch_sampler(dataloader, new_batch_sampler): | |||||
params = {k: getattr(dataloader, k) for k in params_keys} | params = {k: getattr(dataloader, k) for k in params_keys} | ||||
params["batch_sampler"] = new_batch_sampler | params["batch_sampler"] = new_batch_sampler | ||||
return type(dataloader)(**params) | return type(dataloader)(**params) | ||||
# TODO 这里是否可以auto_param_call一下 | |||||
# return auto_param_call(type(dataloader), params, {'self': type(dataloader).__new__()}, | |||||
# signature_fn=type(dataloader).__init__) | |||||
def optimizer_state_to_device(state, device): | def optimizer_state_to_device(state, device): | ||||
@@ -51,6 +51,7 @@ class LoggerSingleton(type): | |||||
class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | ||||
def __init__(self, name): | def __init__(self, name): | ||||
super().__init__(name) | super().__init__(name) | ||||
self._warning_msgs = set() | |||||
def add_file(self, path: Optional[Union[str, Path]] = None, level='AUTO', remove_other_handlers: bool = False, | def add_file(self, path: Optional[Union[str, Path]] = None, level='AUTO', remove_other_handlers: bool = False, | ||||
mode: str = "w"): | mode: str = "w"): | ||||
@@ -108,10 +109,25 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||||
kwargs = self._add_rank_info(kwargs) | kwargs = self._add_rank_info(kwargs) | ||||
self._log(WARNING, msg, args, **kwargs) | self._log(WARNING, msg, args, **kwargs) | ||||
def warning_once(self, msg, *args, **kwargs): | |||||
""" | |||||
相同的 warning 内容只会 warning 一次 | |||||
:param msg: | |||||
:param args: | |||||
:param kwargs: | |||||
:return: | |||||
""" | |||||
if msg not in self._warning_msgs: | |||||
if self.isEnabledFor(WARNING): | |||||
kwargs = self._add_rank_info(kwargs) | |||||
self._log(WARNING, msg, args, **kwargs) | |||||
self._warning_msgs.add(msg) | |||||
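A usage sketch for the new method; deduplication is keyed on the exact message text:

```python
from fastNLP.core.log import logger

for _ in range(3):
    logger.warning_once("validate_step is missing; falling back to test_step")
# the warning is emitted exactly once
```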
def warn(self, msg, *args, **kwargs): | def warn(self, msg, *args, **kwargs): | ||||
warnings.warn("The 'warn' method is deprecated, " | |||||
"use 'warning' instead", DeprecationWarning, 2) | |||||
self.warning(msg, *args, **kwargs) | |||||
if self.isEnabledFor(WARNING): | |||||
kwargs = self._add_rank_info(kwargs) | |||||
self._log(WARNING, msg, args, **kwargs) | |||||
def error(self, msg, *args, **kwargs): | def error(self, msg, *args, **kwargs): | ||||
""" | """ | ||||
@@ -14,8 +14,7 @@ from fastNLP.core.utils.utils import seq_len_to_mask | |||||
class Accuracy(Metric): | class Accuracy(Metric): | ||||
def __init__(self, backend: Union[str, Backend, None] = 'auto', | |||||
aggregate_when_get_metric: bool = True): | |||||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True): | |||||
super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | ||||
self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend) | self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend) | ||||
self.register_element(name='total', value=0, aggregate_method="sum", backend=backend) | self.register_element(name='total', value=0, aggregate_method="sum", backend=backend) | ||||
@@ -64,7 +63,7 @@ class Accuracy(Metric): | |||||
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | ||||
else: | else: | ||||
raise RuntimeError(f"when pred havesize:{pred.shape}, target should have size: {pred.shape} or " | |||||
raise RuntimeError(f"when pred have size:{pred.shape}, target should have size: {pred.shape} or " | |||||
f"{pred.shape[:-1]}, got {target.shape}.") | f"{pred.shape[:-1]}, got {target.shape}.") | ||||
if masks is not None: | if masks is not None: | ||||
@@ -23,14 +23,14 @@ __all__ = [ | |||||
"BucketedBatchSampler", | "BucketedBatchSampler", | ||||
"ReproducibleBatchSampler", | "ReproducibleBatchSampler", | ||||
"re_instantiate_sampler", | |||||
"conversion_between_reproducible_and_unrepeated_sampler" | |||||
"re_instantiate_sampler" | |||||
] | ] | ||||
from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler | from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler | ||||
from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler | from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler | ||||
from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler | from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler | ||||
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler | from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler | ||||
from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler | |||||
from .utils import re_instantiate_sampler | |||||
from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler | |||||
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler | from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler | ||||
@@ -0,0 +1,33 @@ | |||||
from fastNLP.core.samplers import re_instantiate_sampler | |||||
from fastNLP.core.samplers.reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, \ | |||||
SortedSampler | |||||
from fastNLP.core.samplers.unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, \ | |||||
UnrepeatedSequentialSampler, UnrepeatedSortedSampler | |||||
def conversion_between_reproducible_and_unrepeated_sampler(sampler): | |||||
""" | |||||
将 sampler 替换成其对应的 reproducible 版本或 unrepeated 版本。如果输入是 UnrepeatedSampler 但是没找到对应的 | |||||
ReproducibleSampler(或反之),则会 raise TypeError。 | |||||
:param sampler: | |||||
:return: | |||||
""" | |||||
assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \ | |||||
"The sampler must be UnrepeatedSampler or ReproducibleSampler" | |||||
if isinstance(sampler, UnrepeatedSampler): | |||||
if isinstance(sampler, UnrepeatedRandomSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler) | |||||
elif isinstance(sampler, UnrepeatedSequentialSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler) | |||||
elif isinstance(sampler, UnrepeatedSortedSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler) | |||||
raise TypeError(f"{sampler.__class__} has no unrepeated version.") | |||||
else: | |||||
if isinstance(sampler, RandomSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler) | |||||
elif isinstance(sampler, SequentialSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler) | |||||
elif isinstance(sampler, SortedSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler) | |||||
raise TypeError(f"{sampler.__class__} has no reproducible version.") |
@@ -4,16 +4,18 @@ __all__ = [ | |||||
] | ] | ||||
import math | import math | ||||
from array import array | |||||
from copy import deepcopy | from copy import deepcopy | ||||
from typing import Dict, Union, List | from typing import Dict, Union, List | ||||
from itertools import chain | from itertools import chain | ||||
import os | |||||
import numpy as np | import numpy as np | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from .utils import create_array, NumConsumedSamplesArray | |||||
from abc import abstractmethod | from abc import abstractmethod | ||||
from fastNLP.envs.env import FASTNLP_DEQUE_SIZE | |||||
class ReproducibleBatchSampler: | class ReproducibleBatchSampler: | ||||
@@ -34,6 +36,13 @@ class ReproducibleBatchSampler: | |||||
@abstractmethod | @abstractmethod | ||||
def state_dict(self): | def state_dict(self): | ||||
""" | |||||
由于现在的 DataLoader 都存在预取数据的功能,因此请参考 RandomBatchSampler 中 states 里面 num_consumed_samples_array 的实现 | |||||
来正确设置该值。其思想是记录每个 batch yield 出去时对应的 num_consumed_samples,在 Trainer.save 时会根据 Trainer 实际 forward | |||||
了多少个 batch,从 num_consumed_samples_array 取出对应的 num_consumed_samples 进行存储。 | |||||
:return: | |||||
""" | |||||
raise NotImplementedError("Each specific batch_sampler should implement its own `state_dict` method.") | raise NotImplementedError("Each specific batch_sampler should implement its own `state_dict` method.") | ||||
@abstractmethod | @abstractmethod | ||||
@@ -67,7 +76,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
self.drop_last = drop_last | self.drop_last = drop_last | ||||
self.data_idx = kwargs.get("data_idx", 0) | |||||
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) | |||||
self.index_list = kwargs.get("index_list", self._iterate_sampler()) | self.index_list = kwargs.get("index_list", self._iterate_sampler()) | ||||
self.need_reinitialize = kwargs.get("need_reinitialize", False) | self.need_reinitialize = kwargs.get("need_reinitialize", False) | ||||
@@ -80,36 +89,40 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
# 说明是在初始化时传入的是一个 sampler,理论上对应于 dataloader 在初始化时没有 batch_size,也没有 batch_sampler 的情况; | # 说明是在初始化时传入的是一个 sampler,理论上对应于 dataloader 在初始化时没有 batch_size,也没有 batch_sampler 的情况; | ||||
else: | else: | ||||
_index_lst.append(idx) | _index_lst.append(idx) | ||||
# 64 位机器的 unsigned int 为 4 个字节,能表示的最大大小为 4294967295; | |||||
if len(_index_lst) > 4294967295: | |||||
# 注意 self.index_list 内存放的是全部数据的 index; | |||||
# unsigned long | |||||
_index_lst = array("L", _index_lst) | |||||
else: | |||||
# unsigned int | |||||
_index_lst = array("I", _index_lst) | |||||
_index_lst = create_array(len(_index_lst), _index_lst) | |||||
return _index_lst | return _index_lst | ||||
def __iter__(self): | def __iter__(self): | ||||
if self.need_reinitialize: | if self.need_reinitialize: | ||||
self.index_list = self._iterate_sampler() | self.index_list = self._iterate_sampler() | ||||
self.data_idx = 0 | |||||
self.num_consumed_samples = 0 | |||||
else: | else: | ||||
self.need_reinitialize = True | self.need_reinitialize = True | ||||
batch = [] | batch = [] | ||||
if self.data_idx: | |||||
index_list = self.index_list[self.data_idx:] | |||||
if self.num_consumed_samples: | |||||
index_list = self.index_list[self.num_consumed_samples:] | |||||
else: | else: | ||||
index_list = self.index_list | index_list = self.index_list | ||||
# 记住每个 batch 对应的 consumed_samples。需要这个的原因是现在的 dataloader 都存在预取数据的设计,需要再结合 Trainer 中的 | |||||
# batch_idx_in_epoch 才能最终确定实际消耗的数据。这个变量记录每次 yield 出去时真实的 num_consumed_samples 数值。 | |||||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=int(os.environ.get(FASTNLP_DEQUE_SIZE, 30)), | |||||
num_consumed_samples=self.num_consumed_samples) | |||||
for idx in index_list: | for idx in index_list: | ||||
batch.append(idx) | batch.append(idx) | ||||
self.data_idx += 1 | |||||
if len(batch) == self.batch_size: | if len(batch) == self.batch_size: | ||||
self.num_consumed_samples += self.batch_size # [16, 32, 48, 64,..., ] | |||||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||||
yield batch | yield batch | ||||
batch = [] | batch = [] | ||||
if len(batch) > 0 and not self.drop_last: | if len(batch) > 0 and not self.drop_last: | ||||
self.num_consumed_samples += len(batch) | |||||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||||
yield batch | yield batch | ||||
# 需要重置防止边界条件问题 | |||||
self.num_consumed_samples = 0 | |||||
delattr(self, 'num_consumed_samples_array') | |||||
def __len__(self) -> int: | def __len__(self) -> int: | ||||
if self.drop_last: | if self.drop_last: | ||||
@@ -118,7 +131,13 @@ class RandomBatchSampler(ReproducibleBatchSampler):
             return (len(self.index_list) + self.batch_size - 1) // self.batch_size

     def state_dict(self) -> Dict:
-        return {"index_list": deepcopy(self.index_list), "data_idx": self.data_idx, 'sampler_type': self.__class__.__name__}
+        states = {
+            "index_list": deepcopy(self.index_list),
+            "num_consumed_samples": self.num_consumed_samples,
+            'sampler_type': self.__class__.__name__
+        }
+        states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None)
+        return states

     def load_state_dict(self, states: Dict):
         assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
@@ -128,11 +147,11 @@ class RandomBatchSampler(ReproducibleBatchSampler):
         assert len(_index_list) == len(self.index_list), "The number of samples is different between the checkpoint " \
                                                          "record and current dataset."
         self.index_list = _index_list
-        self.data_idx = states["data_idx"]
+        self.num_consumed_samples = states["num_consumed_samples"]
         self.need_reinitialize = False

     def set_distributed(self, num_replicas, rank, pad=True):
-        raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.")
+        raise RuntimeError(f"RandomBatchSampler does not support to change to distributed training.")

     def set_epoch(self, epoch):
         if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch):
@@ -141,10 +160,10 @@ class RandomBatchSampler(ReproducibleBatchSampler):
     @property
     def batch_idx_in_epoch(self):
         if self.drop_last:
-            return len(self.index_list) // self.batch_size - (len(self.index_list) - self.data_idx) // self.batch_size
+            return len(self.index_list) // self.batch_size - (len(self.index_list) - self.num_consumed_samples) // self.batch_size
         else:
             return (len(self.index_list) + self.batch_size - 1) // self.batch_size - \
-                   (len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size
+                   (len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size
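As a worked example of the property above: with len(self.index_list) == 10, batch_size == 4 and drop_last == False, after 8 samples have been consumed it returns ceil(10/4) - ceil((10-8)/4) = 3 - 1 = 2, i.e. two batches of the current epoch have already been yielded.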
 class BucketedBatchSampler(ReproducibleBatchSampler):
@@ -166,8 +185,8 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         :param kwargs: reserved for fastNLP's internal use
         """
         super().__init__()
-        if isinstance(dataset, DataSet):
-            length = dataset.get_field(length)
+        if isinstance(dataset, DataSet) and isinstance(length, str):
+            length = dataset.get_field(length).content
             if not isinstance(length[0], int):
                 length = list(map(len, length))
         else:
@@ -180,7 +199,6 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         self.length = np.array(length, dtype=int)  # indices arranged from longest to shortest
         self.sorted_indices = np.argsort(self.length)[::-1]  # sorted by length, descending
         self.batch_size = batch_size
         self.num_batch_per_bucket = num_batch_per_bucket
         self.shuffle = shuffle
@@ -212,13 +230,13 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         self.rank = rank
         self.pad = pad

-        num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \
-            else len(self.dataset)
-
-        if self.drop_last:
-            assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \
-                                                                   "than the number of replicates multiplied " \
-                                                                   "with batch_size when drop_last=True."
+        # num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \
+        #     else len(self.dataset)
+        #
+        # if self.drop_last:
+        #     assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \
+        #                                                            "than the number of replicates multiplied " \
+        #                                                            "with batch_size when drop_last=True."

         return self
@@ -243,7 +261,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \
             self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))

-    def __len__(self):
+    def __len__(self) -> int:
         """
         Returns how many more batches of data the current sampler will yield
@@ -309,11 +327,15 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
             batches = batches[:-1]

+        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30),
+                                                                  num_consumed_samples=self.num_consumed_samples)
         for batch in batches:
             self.num_consumed_samples += self.num_replicas * len(batch)
+            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield list(map(int, batch))
         self.during_iter = False
         self.num_consumed_samples = 0
+        delattr(self, 'num_consumed_samples_array')
         self.old_batch_size = self.batch_size
         self.old_num_batch_per_bucket = self.num_batch_per_bucket
         self.old_num_replicas = self.num_replicas
@@ -356,7 +378,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
             batch_indices = list(batch_indices[:-1])
         rng = np.random.default_rng(abs(seed))  # guard against differing bucket lengths affecting the random state
         rng.shuffle(batch_indices)  # shuffle across batches too; this keeps the batch lengths on every card close to each other
-        batches = (np.array(batches)[batch_indices]).tolist()
+        batches = (np.array(batches, dtype=object)[batch_indices]).tolist()
         if last_batches:
             batches = batches + last_batches
         return batches
@@ -365,21 +387,16 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         if self.old_batch_size != self.batch_size or self.old_num_batch_per_bucket != self.num_batch_per_bucket:
             raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been"
                                " consumed. ")
-        states = {
-            'seed': self.seed,
-            'epoch': self.epoch,
-            'num_consumed_samples': self.num_consumed_samples,  # note this value counts the data trained on all ranks;
-            'sampler_type': self.__class__.__name__,
-            'length': len(self.dataset),
-            'shuffle': self.shuffle,
-            'batch_size': self.batch_size,
-            'num_batch_per_bucket': self.num_batch_per_bucket,
-            'num_replicas': self.num_replicas
-        }
+        states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
+                  'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
+                  'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket,
+                  'num_replicas': self.num_replicas,
+                  'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)}
         return states

     def load_state_dict(self, states: Dict):
-        # if self.during_iter is True, data_idx must be 0;
+        # if self.during_iter is True, num_consumed_samples must be 0;
         assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
                                           "during an unfinished iteration."
@@ -1,16 +1,21 @@
-__all__ = [
-    'ReproducibleSampler',
-    'RandomSampler',
-    "SortedSampler",
-    "SequentialSampler"
-]
-
 from typing import Dict, List, Union
 import math
+import os

 import numpy as np

 from fastNLP.core.log import logger
 from fastNLP.core.dataset import DataSet
+from fastNLP.envs.env import FASTNLP_DEQUE_SIZE
+from .utils import NumConsumedSamplesArray

+__all__ = [
+    'ReproducibleSampler',
+    'RandomSampler',
+    "SortedSampler",
+    "SequentialSampler"
+]

 class ReproducibleSampler:
@@ -30,6 +35,13 @@ class ReproducibleSampler:
         raise NotImplementedError("Each specific sampler should implement its own `__iter__` method.")

     def state_dict(self):
+        """
+        Since today's DataLoaders all prefetch data, refer to how RandomSampler fills in num_consumed_samples_array
+        inside its states to set this value correctly. The idea is to record the num_consumed_samples corresponding to
+        each index; at Trainer.save time, the number of samples the Trainer really forwarded is used to look up the
+        matching num_consumed_samples in num_consumed_samples_array, and that value is stored.
+        :return:
+        """
         raise NotImplementedError("Each specific sampler should implement its own `state_dict` method.")

     def load_state_dict(self, states):
@@ -109,12 +121,15 @@ class RandomSampler(ReproducibleSampler):
         indices = indices[self.num_consumed_samples:]
         indices = indices[self.rank:len(indices):self.num_replicas]
         assert len(indices) == self.num_left_samples

-        for index in indices:
+        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000),
+                                                                  num_consumed_samples=self.num_consumed_samples)
+        for idx, index in enumerate(indices, start=1):
             self.num_consumed_samples += self.num_replicas
+            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield index
         self.during_iter = False
         self.num_consumed_samples = 0
+        delattr(self, 'num_consumed_samples_array')

     def generate_indices(self) -> List[int]:
         """
@@ -134,18 +149,13 @@ class RandomSampler(ReproducibleSampler):
         return indices

     def state_dict(self) -> Dict:
-        states = {
-            'seed': self.seed,
-            'epoch': self.epoch,
-            'num_consumed_samples': self.num_consumed_samples,  # note this value counts the data trained on all ranks;
-            'sampler_type': self.__class__.__name__,
-            'length': len(self.dataset),
-            'shuffle': self.shuffle
-        }
+        states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
+                  'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
+                  'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)}
         return states

     def load_state_dict(self, states: Dict):
-        # if self.during_iter is True, data_idx must be 0;
+        # if self.during_iter is True, num_consumed_samples must be 0;
         assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
                                           "during an unfinished iteration."
@@ -158,7 +168,7 @@ class RandomSampler(ReproducibleSampler):
         self.seed = states['seed']
         self.epoch = states['epoch']
         self.num_consumed_samples = states['num_consumed_samples']
-        if self.num_consumed_samples>=length:  # if the last sample had already been reached at save time, simply reset to 0
+        if self.num_consumed_samples >= length:  # if the last sample had already been reached at save time, simply reset to 0
             self.num_consumed_samples = 0
         if self.shuffle != states['shuffle']:
             logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, "
@@ -245,11 +255,15 @@ class SequentialSampler(RandomSampler):
         indices = indices[self.rank:len(indices):self.num_replicas]
         assert len(indices) == self.num_left_samples

-        for index in indices:
+        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000),
+                                                                  num_consumed_samples=self.num_consumed_samples)
+        for idx, index in enumerate(indices, start=1):
             self.num_consumed_samples += self.num_replicas
+            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield index
         self.during_iter = False
         self.num_consumed_samples = 0
+        delattr(self, 'num_consumed_samples_array')

     def generate_indices(self) -> List[int]:
         """
@@ -260,15 +274,13 @@ class SequentialSampler(RandomSampler):
         return list(range(len(self.dataset)))

     def state_dict(self) -> Dict:
-        states = {
-            'num_consumed_samples': self.num_consumed_samples,  # note this value counts the data trained on all ranks;
-            'sampler_type': self.__class__.__name__,
-            'length': len(self.dataset),
-        }
+        states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__,
+                  'length': len(self.dataset),
+                  'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)}
         return states

     def load_state_dict(self, states: Dict):
-        # if self.during_iter is True, data_idx must be 0;
+        # if self.during_iter is True, num_consumed_samples must be 0;
         assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
                                           "during an unfinished iteration."
@@ -295,8 +307,8 @@ class SortedSampler(SequentialSampler):
         :param kwargs: reserved for fastNLP's internal use
         """
         super().__init__(dataset=dataset, **kwargs)
-        if isinstance(dataset, DataSet):
-            length = dataset.get_field(length)
+        if isinstance(dataset, DataSet) and isinstance(length, str):
+            length = dataset.get_field(length).content
             if not isinstance(length[0], int):
                 length = list(map(len, length))
         else:
@@ -334,9 +346,13 @@ class SortedSampler(SequentialSampler):
         indices = indices[self.rank:len(indices):self.num_replicas]
         assert len(indices) == self.num_left_samples

-        for index in indices:
+        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000),
+                                                                  num_consumed_samples=self.num_consumed_samples)
+        for idx, index in enumerate(indices, start=1):
             self.num_consumed_samples += self.num_replicas
+            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield index
         self.during_iter = False
         self.num_consumed_samples = 0
+        delattr(self, 'num_consumed_samples_array')
@@ -105,8 +105,8 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler):
         :param kwargs: reserved for fastNLP's internal use
         """
         super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs)
-        if isinstance(dataset, DataSet):
-            length = dataset.get_field(length)
+        if isinstance(dataset, DataSet) and isinstance(length, str):
+            length = dataset.get_field(length).content
             if not isinstance(length[0], int):
                 length = list(map(len, length))
         else:
@@ -1,42 +1,65 @@
 __all__ = [
-    're_instantiate_sampler',
-    'conversion_between_reproducible_and_unrepeated_sampler'
+    're_instantiate_sampler'
 ]

+from array import array
+from typing import Sequence
+from collections import deque
-from fastNLP.core.samplers.unrepeated_sampler import *
-from fastNLP.core.samplers.reproducible_sampler import *

-def re_instantiate_sampler(sampler, new_sampler_class=None):
-    all_attributes = vars(sampler)
-    if new_sampler_class is not None:
-        return new_sampler_class(**all_attributes)
-    return type(sampler)(**all_attributes)

-def conversion_between_reproducible_and_unrepeated_sampler(sampler):
+def create_array(length, fill_value) -> array:
     """
-    Replace the sampler with its corresponding reproducible or unrepeated version. If the input is an UnrepeatedSampler
-    but no matching ReproducibleSampler is found,
+    Automatically create an array according to length; more than 4294967295 entries require 'L', otherwise 'I' is used

-    :param sampler:
+    :param length:
+    :param fill_value:
     :return:
     """
-    assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \
-        "The sampler must be UnrepeatedSampler or ReproducibleSampler"
-    if isinstance(sampler, UnrepeatedSampler):
-        if isinstance(sampler, UnrepeatedRandomSampler):
-            return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler)
-        elif isinstance(sampler, UnrepeatedSequentialSampler):
-            return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler)
-        elif isinstance(sampler, UnrepeatedSortedSampler):
-            return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler)
-        raise TypeError(f"{sampler.__class__} has no unrepeated version.")
+    if not isinstance(fill_value, Sequence):
+        fill_value = [fill_value]*length
+    if length > 4294967295:
+        _index_lst = array("L", fill_value)
     else:
-        if isinstance(sampler, RandomSampler):
-            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler)
-        elif isinstance(sampler, SequentialSampler):
-            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler)
-        elif isinstance(sampler, SortedSampler):
-            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler)
-        raise TypeError(f"{sampler.__class__} has no reproducible version.")
+        _index_lst = array("I", fill_value)
+    return _index_lst
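A quick check of create_array's behaviour (illustrative values, not part of the patch):

    idx = create_array(3, [10, 20, 30])   # a Sequence is stored as-is
    assert idx.typecode == 'I' and list(idx) == [10, 20, 30]
    pad = create_array(4, 0)              # a scalar fill_value is broadcast to the length
    assert list(pad) == [0, 0, 0, 0]      # the 'L' typecode is only chosen past 4294967295 entries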
+def re_instantiate_sampler(sampler, new_sampler_class=None):
+    all_attributes = vars(sampler)
+    if new_sampler_class is not None:
+        return new_sampler_class(**all_attributes)
+    return type(sampler)(**all_attributes)

+class NumConsumedSamplesArray:
+    def __init__(self, buffer_size=2000, num_consumed_samples=0):
+        """
+        Keep buffer_size num_consumed_samples values; indexing returns the num_consumed_samples at a given position
+
+        ex:
+            array = NumConsumedSamplesArray(buffer_size=3)
+            for i in range(10):
+                array.push(i)
+            array[9]  # prints 9, the true num_consumed_samples at this position.
+            array[6]  # raises, because only the 3 most recent values, i.e. [7, 8, 9], are kept and 6 is outside the buffered records
+
+        :param buffer_size: how much history to keep.
+        :param num_consumed_samples: the first num_consumed_samples value.
+        """
+        self.count = 0
+        self.deque = deque(maxlen=buffer_size)
+        if num_consumed_samples is not None:
+            self.push(num_consumed_samples)
+        self.buffer_size = buffer_size

+    def __getitem__(self, item):
+        if len(self.deque) == 0:  # nothing cached yet, i.e. nothing has been written; simply return 0
+            return 0
+        assert isinstance(item, int), "Only int index allowed."
+        assert self.count - len(self.deque) <= item < self.count, f"Only keep {len(self.deque)} history index."
+        index = len(self.deque) - (self.count - item)
+        return self.deque[index]

+    def push(self, num_consumed_samples):
+        self.deque.append(num_consumed_samples)
+        self.count += 1
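How checkpointing code might consume this structure (hypothetical usage following the docstrings above; `trainer` and its counters are assumptions, not part of the patch):

    states = sampler.state_dict()
    history = states['num_consumed_samples_array']
    if history is not None:
        # batches beyond trainer.batch_idx_in_epoch may already sit in the
        # dataloader's prefetch queue, so store the count recorded when the
        # last really-forwarded batch was yielded
        states['num_consumed_samples'] = history[trainer.batch_idx_in_epoch]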
@@ -13,7 +13,6 @@ __all__ = [
     'torch_paddle_move_data_to_device',
     'torch_move_data_to_device',
     'get_fn_arg_names',
-    'check_fn_not_empty_params',
     'auto_param_call',
     'check_user_specific_params',
     'dataclass_to_dict',
@@ -36,7 +35,7 @@ from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_devi
 from .rich_progress import f_rich_progress
 from .torch_paddle_utils import torch_paddle_move_data_to_device
 from .torch_utils import torch_move_data_to_device
-from .utils import get_fn_arg_names, check_fn_not_empty_params, auto_param_call, check_user_specific_params, \
+from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \
     dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \
     indice_collate_wrapper, deprecated, seq_len_to_mask, synchronize_safe_rm, synchronize_mkdir
@@ -46,11 +46,14 @@ def get_paddle_device_id(device: Union[str, int]):
         device = device.lower()
         if device == "cpu":
             raise ValueError("Cannot get device id from `cpu`.")
+        elif device == "gpu":
+            return 0

         match_res = re.match(r"gpu:\d+", device)
         if not match_res:
             raise ValueError(
-                "The device must be a string which is like 'cpu', 'gpu', 'gpu:x'"
+                "The device must be a string which is like 'cpu', 'gpu', 'gpu:x', "
+                f"not '{device}'"
             )
         device_id = device.split(':', 1)[1]
         device_id = int(device_id)
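With this hunk applied, the function behaves as follows (a quick sketch):

    get_paddle_device_id('gpu')    # 0, via the new fast path
    get_paddle_device_id('gpu:2')  # 2
    get_paddle_device_id('tpu')    # ValueError; the message now echoes the offending value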
@@ -6,7 +6,7 @@
 import sys
 from typing import Any, Union, Optional

-from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live
+from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live, Text, ProgressSample
 from rich.progress import ProgressColumn, TimeRemainingColumn, BarColumn, TimeElapsedColumn, TextColumn

 __all__ = [
@@ -146,24 +146,99 @@ class FRichProgress(Progress, metaclass=Singleton):
         if task_id in self._tasks:
             super().stop_task(task_id)
             super().remove_task(task_id)
+            self.refresh()  # so that the bar does not linger

     def start(self) -> None:
         super().start()
         self.console.show_cursor(show=True)

+    def update(
+        self,
+        task_id: TaskID,
+        *,
+        total: Optional[float] = None,
+        completed: Optional[float] = None,
+        advance: Optional[float] = None,
+        description: Optional[str] = None,
+        visible: Optional[bool] = None,
+        refresh: bool = False,
+        **fields: Any,
+    ) -> None:
+        """Update information associated with a task.
+
+        Args:
+            task_id (TaskID): Task id (returned by add_task).
+            total (float, optional): Updates task.total if not None.
+            completed (float, optional): Updates task.completed if not None.
+            advance (float, optional): Add a value to task.completed if not None.
+            description (str, optional): Change task description if not None.
+            visible (bool, optional): Set visible flag if not None.
+            refresh (bool): Force a refresh of progress information. Default is False.
+            **fields (Any): Additional data fields required for rendering.
+        """
+        with self._lock:
+            task = self._tasks[task_id]
+            completed_start = task.completed
+
+            if total is not None and total != task.total:
+                task.total = total
+                task._reset()
+            if advance is not None:
+                task.completed += advance
+            if completed is not None:
+                task.completed = completed
+            if description is not None:
+                task.description = description
+            if visible is not None:
+                task.visible = visible
+            task.fields.update(fields)
+            update_completed = task.completed - completed_start
+
+            current_time = self.get_time()
+            old_sample_time = current_time - self.speed_estimate_period
+            _progress = task._progress
+
+            popleft = _progress.popleft
+            # changed to always keep at least one sample, so that extremely long iterations do not distort the estimate
+            while len(_progress) > 1 and _progress[0].timestamp < old_sample_time:
+                popleft()
+            if update_completed > 0:
+                _progress.append(ProgressSample(current_time, update_completed))
+            if task.completed >= task.total and task.finished_time is None:
+                task.finished_time = task.elapsed
+
+        if refresh:
+            self.refresh()

+class SpeedColumn(ProgressColumn):
+    """
+    Display the speed of a task.
+    """
+    def render(self, task: "Task"):
+        speed = task.speed
+        if speed is None:
+            return Text('-- it./s', style='progress.data.speed')
+        if speed > 0.1:
+            return Text(str(round(speed, 2)) + ' it./s', style='progress.data.speed')
+        else:
+            return Text(str(round(1 / speed, 2)) + ' s/it.', style='progress.data.speed')
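For instance (given rich's Task.speed in iterations per second): a speed of 2.537 renders as '2.54 it./s', while 0.04 falls below the 0.1 threshold and renders as '25.0 s/it.', which is easier to read for slow loops.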
 if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0:
     f_rich_progress = FRichProgress().new_progess(
         "[progress.description]{task.description}",
         "[progress.percentage]{task.percentage:>3.0f}%",
         BarColumn(),
+        SpeedColumn(),
         TimeElapsedColumn(),
         "/",
         TimeRemainingColumn(),
         TextColumn("{task.fields[post_desc]}", justify="right"),
         transient=True,
         disable=False,
-        speed_estimate_period=1
+        speed_estimate_period=30
     )
 else:
     f_rich_progress = DummyFRichProgress()
@@ -1,9 +1,11 @@
 from abc import ABC
 from typing import Any, Union, Optional

-from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_8

+DEFAULT_TORCH_GROUP = None
 if _NEED_IMPORT_TORCH:
     import torch
+    if not _TORCH_GREATER_EQUAL_1_8:
+        DEFAULT_TORCH_GROUP = torch.distributed.distributed_c10d.group.WORLD

 __all__ = [
     'torch_move_data_to_device'
@@ -1,3 +1,4 @@
+import functools
 import inspect
 from inspect import Parameter
 import dataclasses
@@ -24,10 +25,8 @@ from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_GLOBAL_RANK

 __all__ = [
     'get_fn_arg_names',
-    'check_fn_not_empty_params',
     'auto_param_call',
     'check_user_specific_params',
     'dataclass_to_dict',
@@ -44,48 +43,23 @@ __all__ = [
 ]

 def get_fn_arg_names(fn: Callable) -> List[str]:
     r"""
     Return the names of all parameters of a function;

     :param fn: the function to query;
     :return: a list whose elements are the string names of the queried function's parameters;
     """
     return list(inspect.signature(fn).parameters)

-def check_fn_not_empty_params(fn: Optional[Callable] = None, param_num: Optional[int] = None) -> bool:
-    r"""
-    Check that the given batch_step_fn is legal: (1) whether it is callable; (2) whether it has exactly the expected
-    number of parameters without default values;
-    The user may also pass in a partial function, as long as the `trainer` and `batch` parameter slots are kept;
-
-    :param fn: the function passed in to replace the 'step' function of a Loop;
-    :param param_num: the expected number of parameters without default values;
-    :return: bool, whether the given `batch_step_fn` is correct;
-    """
-    if fn is None:
-        return True
-    if not callable(fn):
-        return False
-    else:
-        params = inspect.signature(fn).parameters
-        not_default_params = {}
-        for _name, _param in params.items():
-            if _param.default == Parameter.empty:
-                not_default_params[_name] = _param
-        return len(not_default_params) == param_num

 def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None,
                     mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any:
     r"""
-    1.This function lets the user match by string so as to achieve automatic computation;
+    This function matches fn's formal parameter names against the entries of *args (which must therefore all be dicts)
+    and calls fn with the matched values; if the input data does not match fn's parameter names, the mapping argument can
+    translate them: a (key, value) pair in mapping means the value found under key in *args is passed to the parameter named value.
+    1.This function lets the user match by string so as to achieve automatic calling;
     2.Note that mapping defaults to None; if you want to specify how inputs correspond to the called function's parameters, pass
     mapping in as such a dict; if mapping is not None, we will always first rewrite the keys of the input dicts with it, so be
     sure to check the correctness of mapping yourself;
     3.If a parameter of the called function has a default value, we use that default when no later input provides a value for it,
     and otherwise we use the provided value;
@@ -113,6 +87,7 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
         >>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20}))  # res: 140
         >>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200}))  # res: 240
     """

     if signature_fn is not None:
         if not callable(signature_fn):
             raise ValueError(f"Parameter `signature_fn` should be `Callable`.")
@@ -122,7 +97,8 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
     _kwargs = None
     for _name, _param in _need_params.items():
         if _param.kind == Parameter.VAR_POSITIONAL:
-            raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn.__name__}.")
+            fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
+            raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn_msg}.")
         if _param.kind == Parameter.VAR_KEYWORD:
             _kwargs = (_name, _param)
@@ -135,12 +111,17 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
             _default_params[_name] = _param.default

     if mapping is not None:
-        assert isinstance(mapping, Dict), f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}."
+        fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
+        assert isinstance(mapping, Dict), f"Exception happens when calling {fn_msg}. " \
+                                          f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}."

     _has_params = {}
     duplicate_names = []
     for arg in args:
-        assert isinstance(arg, Dict), "The input part of function `auto_param_call` can only be `Dict` type."
+        if not isinstance(arg, Dict):
+            fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
+            raise TypeError(f"Exception happens when calling {fn_msg}. "
+                            f"The input part of function `auto_param_call` must be `Dict` type, instead of {type(arg)}.")
         for _name, _value in arg.items():
             if mapping is not None and _name in mapping:
                 _name = mapping[_name]
@@ -152,7 +133,8 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
         elif _name in _need_params and not (_has_params[_name] is _value):
             duplicate_names.append(_name)
     if duplicate_names:
-        raise ValueError(f"The following key present in several inputs:{duplicate_names}")
+        fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
+        raise ValueError(f"The following key present in several inputs:{duplicate_names} when calling {fn_msg}.")

     # pass in parameters that have default values and were not overridden by the inputs;
     for _name, _value in _default_params.items():
@@ -161,11 +143,89 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
     if len(_has_params)<len(_need_params):
         miss_params = list(set(_need_params.keys()) - set(_has_params.keys()))
-        raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn.__name__} are not found in the input.")
+        fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
+        _provided_keys = _get_keys(args)
+        raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn_msg} "
+                         f"are not found in the input keys({_provided_keys}).")
     return fn(**_has_params)

+def _get_keys(args: List[Dict]) -> List[List[str]]:
+    """
+    Return the keys of each dict
+
+    :param args:
+    :return:
+    """
+    _provided_keys = []
+    for arg in args:
+        _provided_keys.append(list(arg.keys()))
+    return _provided_keys

+def _get_fun_msg(fn) -> str:
+    """
+    Get basic information about a function, to help with error reporting.
+
+    ex:
+        print(_get_fun_msg(_get_fun_msg))
+        # `_get_fun_msg(fn) -> str`(In file:/Users/hnyan/Desktop/projects/fastNLP/fastNLP/fastNLP/core/utils/utils.py)
+
+    :param callable fn:
+    :return:
+    """
+    if isinstance(fn, functools.partial):
+        return _get_fun_msg(fn.func)
+    try:
+        fn_name = fn.__qualname__ + str(inspect.signature(fn))
+    except:
+        fn_name = str(fn)
+    try:
+        fp = '(In file:' + os.path.abspath(inspect.getfile(fn)) + ')'
+    except:
+        fp = ''
+    msg = f'`{fn_name}`' + fp
+    return msg
+def _check_valid_parameters_number(fn, expected_params: List[str], fn_name=None):
+    """
+    Check whether a function accepts the expected_params parameters (i.e. whether the counts match), excluding self (if
+    fn is a method) and parameters with default values. An error is raised on mismatch.
+
+    :param fn: the function to check; may be a method or a function.
+    :param expected_params: the parameters it is expected to support.
+    :param fn_name: name of fn, for a clearer error when the given fn is not callable.
+    :return:
+    """
+    if fn_name is not None:
+        assert callable(fn), f"{fn_name} should be callable, instead of {type(fn)}."
+
+    parameters = list(inspect.signature(fn).parameters.values())
+    if inspect.ismethod(fn):
+        if len(parameters) > 0 and parameters[0].name == 'self':
+            parameters = parameters[1:]  # drop self
+
+    no_var_param = True  # no *args / **kwargs style parameter
+    number_param_need_value = 0
+    for param in parameters:
+        if param.kind is param.VAR_POSITIONAL:
+            no_var_param = False
+        elif param.kind is param.VAR_KEYWORD:
+            no_var_param = False
+        else:
+            if param.default is param.empty:
+                number_param_need_value += 1
+
+    if len(parameters) < len(expected_params) and no_var_param:
+        raise RuntimeError(f"The function:{_get_fun_msg(fn)} accepts {len(parameters)} parameters, "
+                           f"but {len(expected_params)} parameters:{expected_params} will be provided.")
+    if number_param_need_value > len(expected_params):
+        raise RuntimeError(f"The function:{_get_fun_msg(fn)} expects {len(parameters)} parameters, but only"
+                           f" {len(expected_params)} parameters:{expected_params} will be provided.")
 def check_user_specific_params(user_params: Dict, fn: Callable):
     """
     This function uses the user's input to assign values to the given function's parameters;
@@ -184,7 +244,7 @@ def check_user_specific_params(user_params: Dict, fn: Callable):
     return user_params

-def dataclass_to_dict(data: "dataclass") -> Dict:
+def dataclass_to_dict(data: "dataclasses.dataclass") -> Dict:
     if not is_dataclass(data):
         raise TypeError(f"Parameter `data` can only be `dataclass` type instead of {type(data)}.")
     _dict = dict()
@@ -591,4 +651,24 @@ def synchronize_mkdir(path: Optional[Union[str, Path]]):
     wait_to_success(path.exists)

+def get_class_that_defined_method(method):
+    """
+    Given a method, return the class object that defines it
+
+    :param method:
+    :return:
+    """
+    if isinstance(method, functools.partial):
+        return get_class_that_defined_method(method.func)
+    if inspect.ismethod(method) or (inspect.isbuiltin(method) and getattr(method, '__self__', None) is not None and getattr(method.__self__, '__class__', None)):
+        for cls in inspect.getmro(method.__self__.__class__):
+            if method.__name__ in cls.__dict__:
+                return cls
+        method = getattr(method, '__func__', method)  # fallback to __qualname__ parsing
+    if inspect.isfunction(method):
+        cls = getattr(inspect.getmodule(method),
+                      method.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0],
+                      None)
+        if isinstance(cls, type):
+            return cls
+    return getattr(method, '__objclass__', None)  # handle special descriptor objects
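Illustrative calls (assuming a toy class, not from the patch):

    class A:
        def f(self):
            pass

    get_class_that_defined_method(A().f)                   # A, via the bound method's MRO
    get_class_that_defined_method(A.f)                     # A, via __qualname__ parsing
    get_class_that_defined_method(functools.partial(A.f))  # A, the partial is unwrapped first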
@@ -6,7 +6,8 @@ __all__ = [
     'is_cur_env_distributed',
     'get_global_rank',
     'rank_zero_call',
-    'all_rank_call'
+    'all_rank_call',
+    'get_gpu_count'
 ]

@@ -14,5 +15,5 @@ from .env import *
 from .set_env_on_import import set_env_on_import
 from .set_backend import dump_fastnlp_backend
 from .imports import *
-from .utils import _module_available
+from .utils import _module_available, get_gpu_count
 from .distributed import *
@@ -45,6 +45,8 @@ FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK'
 # todo: add comment
 FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH"

+# default size of deques initialized inside fastNLP
+FASTNLP_DEQUE_SIZE = 'FASTNLP_DEQUE_SIZE'

 # todo: add comment; variables used directly
 FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar"
@@ -5,13 +5,13 @@
 import os
 import json
 import sys
-import subprocess
 from collections import defaultdict

 from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED
 from fastNLP.envs.imports import SUPPORT_BACKENDS
-from fastNLP.envs.utils import _module_available
+from fastNLP.envs.utils import _module_available, get_gpu_count

 def _set_backend():
     """
@@ -56,17 +56,18 @@ def _set_backend():
         if 'PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ:
             # in a distributed subprocess, derive the devices this process really owns from USER_CUDA_VISIBLE_DEVICES
             selected_gpus = os.environ['FLAGS_selected_gpus'].split(',')
-            if user_visible_devices is not None and user_visible_devices != "":
+            if user_visible_devices is not None:
                 # the user launched distributed training through CUDA_VISIBLE_DEVICES
                 # after set_backend, the user's setting is kept in USER_CUDA_VISIBLE_DEVICES
                 # we need to find the device numbers actually in use from it
                 user_visible_devices = user_visible_devices.split(",")
                 selected_gpus = [user_visible_devices[int(i)] for i in selected_gpus]
             else:
-                # set USER_CUDA_VISIBLE_DEVICES to signal that, from the user's viewpoint, all devices are visible
-                os.environ[USER_CUDA_VISIBLE_DEVICES] = ""
-                # TODO the [0] here may be wrong on a single node with several cards
-                os.environ['CUDA_VISIBLE_DEVICES'] = selected_gpus[0]
+                # USER_CUDA_VISIBLE_DEVICES was not found, so set it to all devices
+                os.environ[USER_CUDA_VISIBLE_DEVICES] = ",".join(map(str, list(
+                    range(get_gpu_count())
+                )))
+            os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(selected_gpus)
             os.environ['FLAGS_selected_gpus'] = ",".join([str(g) for g in range(len(selected_gpus))])
             os.environ['FLAGS_selected_accelerators'] = ",".join([str(g) for g in range(len(selected_gpus))])
         elif 'CUDA_VISIBLE_DEVICES' in os.environ:
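A concrete trace of the branch above (illustrative values): suppose the user exported CUDA_VISIBLE_DEVICES=3,4,5, so USER_CUDA_VISIBLE_DEVICES holds '3,4,5', and paddle launches this subprocess with FLAGS_selected_gpus=0,2. The mapping then picks the real devices ['3', '5'], CUDA_VISIBLE_DEVICES becomes '3,5', and FLAGS_selected_gpus / FLAGS_selected_accelerators are renumbered to '0,1' relative to the restricted view.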
@@ -78,7 +79,9 @@ def _set_backend():
             else:
                 # if unset, restrict to a single card to avoid occupying other cards in multi-process runs
                 os.environ['CUDA_VISIBLE_DEVICES'] = '0'
-                os.environ[USER_CUDA_VISIBLE_DEVICES] = ''
+                os.environ[USER_CUDA_VISIBLE_DEVICES] = ",".join(map(str, list(
+                    range(get_gpu_count())
+                )))

     elif backend == 'jittor':
         assert _module_available(backend), f"You must have {backend} available to use {backend} backend."
@@ -36,8 +36,7 @@ def set_env_on_import_torch():

 # TODO paddle may need set this
 def set_env_on_import_paddle():
-    # todo need to set FASTNLP_GLOBAL_RANK and FASTNLP_LAUNCH_PROCESS
-    if "PADDLE_TRANERS_NUM" in os.environ and "PADDLE_TRAINER_ID" in os.environ \
+    if "PADDLE_TRAINERS_NUM" in os.environ and "PADDLE_TRAINER_ID" in os.environ \
             and "PADDLE_RANK_IN_NODE" in os.environ:
         # the environment variables of a distributed setup were detected
         os.environ[FASTNLP_GLOBAL_RANK] = os.environ["PADDLE_TRAINER_ID"]
@@ -3,6 +3,7 @@ from typing import Callable
 import importlib
 from pkg_resources import DistributionNotFound
 from packaging.version import Version
+import subprocess
 import pkg_resources
@@ -46,3 +47,15 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
     if use_base_version:
         pkg_version = Version(pkg_version.base_version)
     return op(pkg_version, Version(version))

+def get_gpu_count():
+    """
+    Get the number of GPUs via the command line
+
+    :return: the number of GPUs, or -1 if there is no GPU device
+    """
+    try:
+        lines = subprocess.check_output(['nvidia-smi', '--query-gpu=memory.used', '--format=csv'])
+        # after splitting, the leading header line and the trailing newline still have to be removed
+        return len(lines.split(b"\n")) - 2
+    except:
+        return -1
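The arithmetic relies on nvidia-smi emitting one CSV header line, one line per GPU, and a trailing newline; on a hypothetical two-GPU machine the output would look like:

    memory.used [MiB]
    1024 MiB
    0 MiB

so splitting on b"\n" yields four elements and subtracting 2 leaves the GPU count. The bare except also covers machines where nvidia-smi is not installed at all.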
@@ -251,10 +251,10 @@ class DataBundle:
     def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True,
                          ignore_miss_dataset=True, progress_desc: str = '', show_progress_bar: bool = True):
         r"""
-        Apply the :meth:`~fastNLP.DataSet.apply_field_more` method to every dataset in the :class:`~fastNLP.io.DataBundle`
+        Apply the :method:`~fastNLP.DataSet.apply_field_more` method to every dataset in the :class:`~fastNLP.io.DataBundle`

         .. note::
-            For the difference between ``apply_field_more`` and ``apply_field``, see the introduction in :meth:`fastNLP.DataSet.apply_more` about ``apply_more``
+            For the difference between ``apply_field_more`` and ``apply_field``, see the introduction in :method:`fastNLP.DataSet.apply_more` about ``apply_more``
             versus ``apply``.

         :param callable func: takes an ``Instance`` from the ``DataSet`` and returns a dict whose keys are field names and whose values are the corresponding results
@@ -285,7 +285,7 @@ class DataBundle:
     def apply(self, func: Callable, new_field_name: str, num_proc: int = 0,
               progress_desc: str = '', show_progress_bar: bool = True, _apply_field: str = None):
         r"""
-        Apply the :meth:`~fastNLP.DataSet.apply` method to every dataset in the :class:`~fastNLP.io.DataBundle`
+        Apply the :method:`~fastNLP.DataSet.apply` method to every dataset in the :class:`~fastNLP.io.DataBundle`

         Applies the apply method to every dataset in the DataBundle
@@ -309,10 +309,10 @@ class DataBundle:
     def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0,
                    progress_desc: str = '', show_progress_bar: bool = True):
         r"""
-        Apply the :meth:`~fastNLP.DataSet.apply_more` method to every dataset in the :class:`~fastNLP.io.DataBundle`
+        Apply the :method:`~fastNLP.DataSet.apply_more` method to every dataset in the :class:`~fastNLP.io.DataBundle`

         .. note::
-            For the difference between ``apply_more`` and ``apply``, see the introduction in :meth:`fastNLP.DataSet.apply_more` about ``apply_more``
+            For the difference between ``apply_more`` and ``apply``, see the introduction in :method:`fastNLP.DataSet.apply_more` about ``apply_more``
             versus ``apply``.

         :param callable func: takes an ``Instance`` from the ``DataSet`` and returns a dict whose keys are field names and whose values are the corresponding results
@@ -87,7 +87,7 @@ class CLSBasePipe(Pipe):
     def process_from_file(self, paths) -> DataBundle:
         r"""
-        Given file paths, return a processed DataBundle object. For the path forms supported by paths, see ::meth:`fastNLP.io.Loader.load()`
+        Given file paths, return a processed DataBundle object. For the path forms supported by paths, see ::method:`fastNLP.io.Loader.load()`

         :param paths:
         :return: DataBundle
@@ -164,7 +164,7 @@ class GraphBuilderBase:
     def build_graph_from_file(self, path: str):
         r"""
-        Given a file path, return a processed scipy_sparse_matrix object. For the path forms supported by paths, see ::meth:`fastNLP.io.Loader.load()`
+        Given a file path, return a processed scipy_sparse_matrix object. For the path forms supported by paths, see ::method:`fastNLP.io.Loader.load()`

         :param path:
         :return: scipy_sparse_matrix
@@ -33,7 +33,7 @@ class Pipe:
     def process_from_file(self, paths: str) -> DataBundle:
         r"""
-        Given file paths, return a processed DataBundle object. For the path forms supported by paths, see ::meth:`fastNLP.io.Loader.load()`
+        Given file paths, return a processed DataBundle object. For the path forms supported by paths, see ::method:`fastNLP.io.Loader.load()`

         :param str paths:
         :return: DataBundle
@@ -1,7 +1,7 @@
 import pytest

 from functools import reduce
-from fastNLP.core.callbacks.callback_events import Filter
+from fastNLP.core.callbacks.callback_events import Events, Filter

 class TestFilter:
@@ -10,7 +10,7 @@ import re

 from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback
 from fastNLP.core.controllers.trainer import Trainer
-from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK
+from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK
 from tests.helpers.utils import magic_argv_env_context
 from fastNLP.core import synchronize_safe_rm
@@ -238,7 +238,7 @@ def test_model_checkpoint_callback_2(
     from fastNLP.core.callbacks.callback_events import Events

-    @Trainer.on(Events.ON_TRAIN_EPOCH_END)
+    @Trainer.on(Events.on_train_epoch_end)
     def raise_exception(trainer):
         if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4:
             raise NotImplementedError
@@ -0,0 +1,93 @@ | |||||
""" | |||||
这个文件测试用户以python -m paddle.distributed.launch 启动的情况 | |||||
看看有没有用pytest执行的机会 | |||||
python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py | |||||
""" | |||||
import os | |||||
os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
import sys | |||||
sys.path.append("../../../") | |||||
from dataclasses import dataclass | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.core.metrics.accuracy import Accuracy | |||||
from fastNLP.core.callbacks.progress_callback import RichCallback | |||||
from fastNLP.core.callbacks import Callback | |||||
import paddle | |||||
from paddle.optimizer import Adam | |||||
from paddle.io import DataLoader | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||||
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | |||||
from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback | |||||
@dataclass | |||||
class MNISTTrainFleetConfig: | |||||
num_labels: int = 10 | |||||
feature_dimension: int = 10 | |||||
batch_size: int = 32 | |||||
shuffle: bool = True | |||||
validate_every = -1 | |||||
def test_trainer_fleet( | |||||
driver, | |||||
device, | |||||
callbacks, | |||||
n_epochs, | |||||
): | |||||
model = PaddleNormalModel_Classification_1( | |||||
num_labels=MNISTTrainFleetConfig.num_labels, | |||||
feature_dimension=MNISTTrainFleetConfig.feature_dimension | |||||
) | |||||
optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) | |||||
train_dataloader = DataLoader( | |||||
dataset=PaddleRandomMaxDataset(6400, MNISTTrainFleetConfig.feature_dimension), | |||||
batch_size=MNISTTrainFleetConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
val_dataloader = DataLoader( | |||||
dataset=PaddleRandomMaxDataset(1280, MNISTTrainFleetConfig.feature_dimension), | |||||
batch_size=MNISTTrainFleetConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
validate_dataloaders = val_dataloader | |||||
validate_every = MNISTTrainFleetConfig.validate_every | |||||
metrics = {"acc": Accuracy()} | |||||
trainer = Trainer( | |||||
model=model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=optimizers, | |||||
train_dataloader=train_dataloader, | |||||
validate_dataloaders=validate_dataloaders, | |||||
validate_every=validate_every, | |||||
input_mapping=None, | |||||
output_mapping=None, | |||||
metrics=metrics, | |||||
n_epochs=n_epochs, | |||||
callbacks=callbacks, | |||||
output_from_new_proc="logs", | |||||
) | |||||
trainer.run() | |||||
if __name__ == "__main__": | |||||
driver = "fleet" | |||||
device = [0,2,3] | |||||
# driver = "paddle" | |||||
# device = 2 | |||||
callbacks = [ | |||||
# RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), | |||||
RichCallback(5), | |||||
] | |||||
test_trainer_fleet( | |||||
driver=driver, | |||||
device=device, | |||||
callbacks=callbacks, | |||||
n_epochs=5, | |||||
) |
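Note the ordering at the top of this script: FASTNLP_BACKEND is exported before any fastNLP or paddle import, which suggests the backend choice must be fixed at import time, and sys.path.append makes tests.helpers importable when the file is launched directly rather than through pytest. A minimal skeleton under those assumptions:

    import os
    os.environ["FASTNLP_BACKEND"] = "paddle"  # assumed to take effect only if set before importing fastNLP

    import sys
    sys.path.append("../../../")              # lets `tests.helpers.*` resolve outside pytest

    from fastNLP.core.controllers.trainer import Trainer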
@@ -0,0 +1,98 @@ | |||||
""" | |||||
这个文件测试用户以python -m paddle.distributed.launch 启动的情况 | |||||
并且自己初始化了 fleet | |||||
python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py | |||||
""" | |||||
import os | |||||
os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
import sys | |||||
sys.path.append("../../../") | |||||
from dataclasses import dataclass | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.core.metrics.accuracy import Accuracy | |||||
from fastNLP.core.callbacks.progress_callback import RichCallback | |||||
from fastNLP.core.callbacks import Callback | |||||
import paddle | |||||
from paddle.optimizer import Adam | |||||
from paddle.io import DataLoader | |||||
import paddle.distributed.fleet as fleet | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_2 | |||||
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | |||||
from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback | |||||
@dataclass | |||||
class MNISTTrainFleetConfig: | |||||
num_labels: int = 10 | |||||
feature_dimension: int = 10 | |||||
batch_size: int = 32 | |||||
shuffle: bool = True | |||||
validate_every = -1 | |||||
def test_trainer_fleet( | |||||
driver, | |||||
device, | |||||
callbacks, | |||||
n_epochs, | |||||
): | |||||
fleet.init(is_collective=True) | |||||
model = PaddleNormalModel_Classification_2( | |||||
num_labels=MNISTTrainFleetConfig.num_labels, | |||||
feature_dimension=MNISTTrainFleetConfig.feature_dimension, | |||||
) | |||||
optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) | |||||
model = fleet.distributed_model(model) | |||||
optimizers = fleet.distributed_optimizer(optimizers) | |||||
train_dataloader = DataLoader( | |||||
dataset=PaddleRandomMaxDataset(6400, MNISTTrainFleetConfig.feature_dimension), | |||||
batch_size=MNISTTrainFleetConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
val_dataloader = DataLoader( | |||||
dataset=PaddleRandomMaxDataset(1280, MNISTTrainFleetConfig.feature_dimension), | |||||
batch_size=MNISTTrainFleetConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
validate_dataloaders = val_dataloader | |||||
validate_every = MNISTTrainFleetConfig.validate_every | |||||
metrics = {"acc": Accuracy()} | |||||
trainer = Trainer( | |||||
model=model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=optimizers, | |||||
train_dataloader=train_dataloader, | |||||
validate_dataloaders=validate_dataloaders, | |||||
validate_every=validate_every, | |||||
input_mapping=None, | |||||
output_mapping=None, | |||||
metrics=metrics, | |||||
n_epochs=n_epochs, | |||||
callbacks=callbacks, | |||||
output_from_new_proc="logs", | |||||
data_device=f"gpu:{os.environ['CUDA_VISIBLE_DEVICES']}" | |||||
) | |||||
trainer.run() | |||||
if __name__ == "__main__": | |||||
driver = "fleet" | |||||
device = [0,2,3] | |||||
callbacks = [ | |||||
# RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), | |||||
RichCallback(5), | |||||
] | |||||
test_trainer_fleet( | |||||
driver=driver, | |||||
device=device, | |||||
callbacks=callbacks, | |||||
n_epochs=30, | |||||
) |
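The distinguishing feature of this second script is that user code initializes fleet itself before constructing the Trainer. Stripped to the three Paddle calls involved (the trailing comment is an assumption about the driver, not documented behavior):

    import paddle.distributed.fleet as fleet

    fleet.init(is_collective=True)                      # collective data-parallel mode
    model = fleet.distributed_model(model)              # wrap the model for multi-card training
    optimizers = fleet.distributed_optimizer(optimizers)
    # fastNLP's fleet driver presumably detects these pre-wrapped objects
    # and skips wrapping them a second time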
@@ -0,0 +1,25 @@ | |||||
import pytest | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.core.callbacks import Events | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
@magic_argv_env_context | |||||
def test_trainer_torch_without_evaluator(): | |||||
@Trainer.on(Events.ON_TRAIN_EPOCH_BEGIN(every=10)) | |||||
def fn1(trainer): | |||||
pass | |||||
@Trainer.on(Events.ON_TRAIN_BATCH_BEGIN(every=10)) | |||||
def fn2(trainer, batch, indices): | |||||
pass | |||||
with pytest.raises(AssertionError): | |||||
@Trainer.on(Events.ON_TRAIN_BATCH_BEGIN(every=10)) | |||||
def fn3(trainer, batch): | |||||
pass | |||||
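Two behaviors are pinned down by this test: calling an event with every=10 attaches a Filter so the handler only fires on every 10th occurrence of the event, and Trainer.on checks the handler's parameter list against the hook it targets, which is why fn3 (missing the indices parameter) trips an AssertionError at decoration time. A sketch of the filter semantics, with the counter purely illustrative:

    fired_at = []

    @Trainer.on(Events.ON_TRAIN_BATCH_BEGIN(every=10))
    def count(trainer, batch, indices):
        # assuming 1-based filtering, this runs on occurrences 10, 20, 30, ...
        fired_at.append(trainer.global_forward_batches)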
@@ -98,14 +98,16 @@ def model_and_optimizers(request): | |||||
# 测试一下普通的情况; | # 测试一下普通的情况; | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) #, ("torch", 1), ("torch", [0, 1]) | |||||
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) | |||||
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]]) | @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]]) | ||||
@pytest.mark.parametrize("validate_every", [-3]) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_torch_with_evaluator( | def test_trainer_torch_with_evaluator( | ||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
driver, | driver, | ||||
device, | device, | ||||
callbacks, | callbacks, | ||||
validate_every, | |||||
n_epochs=10, | n_epochs=10, | ||||
): | ): | ||||
trainer = Trainer( | trainer = Trainer( | ||||
@@ -118,11 +120,11 @@ def test_trainer_torch_with_evaluator( | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
validate_every=validate_every, | |||||
n_epochs=n_epochs, | n_epochs=n_epochs, | ||||
callbacks=callbacks, | callbacks=callbacks, | ||||
output_from_new_proc="all" | output_from_new_proc="all" | ||||
) | ) | ||||
trainer.run() | trainer.run() | ||||
@@ -143,7 +145,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
accumulation_steps, | accumulation_steps, | ||||
n_epochs=6, | n_epochs=6, | ||||
): | ): | ||||
callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)] | |||||
callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.1, larger_better=True)] | |||||
trainer = Trainer( | trainer = Trainer( | ||||
model=model_and_optimizers.model, | model=model_and_optimizers.model, | ||||
driver=driver, | driver=driver, | ||||
@@ -169,4 +171,42 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@pytest.mark.parametrize("driver,device", [("torch", 1)]) # ("torch", [0, 1]),("torch", 1) | |||||
@magic_argv_env_context | |||||
def test_trainer_validate_every( | |||||
model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
n_epochs=6, | |||||
): | |||||
def validate_every(trainer): | |||||
if trainer.global_forward_batches % 10 == 0: | |||||
print(trainer) | |||||
print("\nfastNLP test validate every.\n") | |||||
print(trainer.global_forward_batches) | |||||
return True | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=n_epochs, | |||||
output_from_new_proc="all", | |||||
validate_every=validate_every | |||||
) | |||||
trainer.run() | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
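Together, these two tests cover both accepted forms of validate_every: an integer (-3 above, presumably meaning evaluate every 3 training batches) and a callable that receives the trainer and returns True whenever evaluation should run. The callable form in its minimal shape:

    def validate_every(trainer):
        # evaluate once every 10 global batches; any truthy return triggers evaluation
        return trainer.global_forward_batches % 10 == 0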
@@ -10,7 +10,7 @@ from typing import Any | |||||
from pathlib import Path | from pathlib import Path | ||||
from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1, TorchNormalModel_Classification_3 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | ||||
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | ||||
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | ||||
@@ -70,7 +70,7 @@ def model_and_optimizers(request): | |||||
trainer_params.output_mapping = None | trainer_params.output_mapping = None | ||||
# elif request.param == 1: | # elif request.param == 1: | ||||
# model = | |||||
return trainer_params | return trainer_params | ||||
@@ -254,7 +254,7 @@ def test_trainer_on_exception( | |||||
): | ): | ||||
from fastNLP.core.callbacks.callback_events import Events | from fastNLP.core.callbacks.callback_events import Events | ||||
@Trainer.on(Events.ON_TRAIN_EPOCH_END) | |||||
@Trainer.on(Events.on_train_epoch_end) | |||||
def raise_exception(trainer): | def raise_exception(trainer): | ||||
if trainer.driver.get_local_rank() == cur_rank: | if trainer.driver.get_local_rank() == cur_rank: | ||||
raise NotImplementedError | raise NotImplementedError | ||||
@@ -307,10 +307,47 @@ def test_torch_distributed_launch_2(version): | |||||
subprocess.check_call(command) | subprocess.check_call(command) | ||||
@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch_ddp", [0, 1])]) | |||||
@magic_argv_env_context | |||||
def test_torch_wo_auto_param_call( | |||||
driver, | |||||
device, | |||||
n_epochs=10, | |||||
): | |||||
model = TorchNormalModel_Classification_3( | |||||
num_labels=NormalClassificationTrainTorchConfig.num_labels, | |||||
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension | |||||
) | |||||
optimizers = SGD(model.parameters(), lr=0.001) | |||||
dataset = TorchNormalDataset_Classification( | |||||
num_labels=NormalClassificationTrainTorchConfig.num_labels, | |||||
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension, | |||||
each_label_data=NormalClassificationTrainTorchConfig.each_label_data, | |||||
seed=NormalClassificationTrainTorchConfig.seed | |||||
) | |||||
train_dataloader = DataLoader( | |||||
dataset=dataset, | |||||
batch_size=NormalClassificationTrainTorchConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
trainer = Trainer( | |||||
model=model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=optimizers, | |||||
train_dataloader=train_dataloader, | |||||
n_epochs=n_epochs, | |||||
model_wo_auto_param_call=True, | |||||
output_from_new_proc="all" | |||||
) | |||||
trainer.run() | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
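The flag under test, model_wo_auto_param_call=True, turns off fastNLP's signature-based argument matching, so the model receives the batch exactly as the dataloader (or input_mapping) produced it; TorchNormalModel_Classification_3 is presumably written for that calling convention. A hypothetical model of that shape:

    import torch

    class WholeBatchModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 2)

        def forward(self, batch):
            # receives the whole batch; the "x" key and "pred" output name are assumptions
            return {"pred": self.linear(batch["x"])}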
@@ -1,83 +1,103 @@ | |||||
import os | |||||
import pytest | import pytest | ||||
from fastNLP.envs.set_backend import set_env | |||||
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle | |||||
set_env_on_import_paddle() | |||||
set_env("paddle") | |||||
import paddle | |||||
os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
from fastNLP.core.drivers import PaddleSingleDriver, PaddleFleetDriver | |||||
from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver | from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver | ||||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | |||||
from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | |||||
from fastNLP.envs import get_gpu_count | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
import paddle | |||||
def test_incorrect_driver(): | def test_incorrect_driver(): | ||||
model = PaddleNormalModel_Classification_1(2, 100) | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
driver = initialize_paddle_driver("torch") | |||||
driver = initialize_paddle_driver("torch", 0, model) | |||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
["cpu", "gpu:0", [1, 2, 3], 0, "gpu:1"] | |||||
["cpu", "gpu:0", 0, [1]] | |||||
) | ) | ||||
def test_get_single_device(device): | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["paddle"] | |||||
) | |||||
def test_get_single_device(driver, device): | |||||
""" | """ | ||||
测试正常情况下初始化PaddleSingleDriver的情况 | 测试正常情况下初始化PaddleSingleDriver的情况 | ||||
""" | """ | ||||
model = PaddleNormalModel_Classification(2, 100) | |||||
driver = initialize_paddle_driver("paddle", device, model) | |||||
model = PaddleNormalModel_Classification_1(2, 100) | |||||
driver = initialize_paddle_driver(driver, device, model) | |||||
assert isinstance(driver, PaddleSingleDriver) | assert isinstance(driver, PaddleSingleDriver) | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
["cpu", "gpu:0", [1, 2, 3], 0, "gpu:1"] | |||||
[0, 1] | |||||
) | ) | ||||
def test_get_single_device_with_visiblde_devices(device): | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["fleet"] | |||||
) | |||||
@magic_argv_env_context | |||||
def test_get_fleet_2(driver, device): | |||||
""" | """ | ||||
测试 CUDA_VISIBLE_DEVICES 启动时初始化PaddleSingleDriver的情况 | |||||
测试 fleet 多卡的初始化情况 | |||||
""" | """ | ||||
# TODO | |||||
model = PaddleNormalModel_Classification(2, 100) | |||||
driver = initialize_paddle_driver("paddle", device, model) | |||||
model = PaddleNormalModel_Classification_1(64, 10) | |||||
driver = initialize_paddle_driver(driver, device, model) | |||||
assert isinstance(driver, PaddleSingleDriver) | |||||
assert isinstance(driver, PaddleFleetDriver) | |||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
[[1, 2, 3]] | |||||
[[0, 2, 3], -1] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["paddle", "fleet"] | |||||
) | ) | ||||
def test_get_fleet(device): | |||||
@magic_argv_env_context | |||||
def test_get_fleet(driver, device): | |||||
""" | """ | ||||
测试 fleet 多卡的初始化情况 | 测试 fleet 多卡的初始化情况 | ||||
""" | """ | ||||
model = PaddleNormalModel_Classification(2, 100) | |||||
driver = initialize_paddle_driver("paddle", device, model) | |||||
model = PaddleNormalModel_Classification_1(64, 10) | |||||
driver = initialize_paddle_driver(driver, device, model) | |||||
assert isinstance(driver, PaddleFleetDriver) | assert isinstance(driver, PaddleFleetDriver) | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | |||||
[[1,2,3]] | |||||
("driver", "device"), | |||||
[("fleet", "cpu")] | |||||
) | ) | ||||
def test_get_fleet(device): | |||||
@magic_argv_env_context | |||||
def test_get_fleet_cpu(driver, device): | |||||
""" | """ | ||||
测试 launch 启动 fleet 多卡的初始化情况 | |||||
测试试图在 cpu 上初始化分布式训练的情况 | |||||
""" | """ | ||||
# TODO | |||||
model = PaddleNormalModel_Classification(2, 100) | |||||
driver = initialize_paddle_driver("paddle", device, model) | |||||
assert isinstance(driver, PaddleFleetDriver) | |||||
model = PaddleNormalModel_Classification_1(64, 10) | |||||
with pytest.raises(ValueError): | |||||
driver = initialize_paddle_driver(driver, device, model) | |||||
def test_device_out_of_range(device): | |||||
@pytest.mark.parametrize( | |||||
"device", | |||||
[-2, [0, get_gpu_count() + 1, 3], [-2], get_gpu_count() + 1] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["paddle", "fleet"] | |||||
) | |||||
@magic_argv_env_context | |||||
def test_device_out_of_range(driver, device): | |||||
""" | """ | ||||
测试传入的device超过范围的情况 | 测试传入的device超过范围的情况 | ||||
""" | """ | ||||
pass | |||||
model = PaddleNormalModel_Classification_1(2, 100) | |||||
with pytest.raises(ValueError): | |||||
driver = initialize_paddle_driver(driver, device, model) |
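Read together, these cases fix the dispatch rule of initialize_paddle_driver: a non-paddle driver name raises ValueError, a single device yields PaddleSingleDriver, driver="fleet" (or a multi-device list) yields PaddleFleetDriver, and fleet-on-cpu or out-of-range device ids raise ValueError. In sketch form:

    model = PaddleNormalModel_Classification_1(2, 100)
    assert isinstance(initialize_paddle_driver("paddle", "gpu:0", model), PaddleSingleDriver)
    assert isinstance(initialize_paddle_driver("fleet", [0, 1], model), PaddleFleetDriver)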
@@ -1,262 +0,0 @@ | |||||
import unittest | |||||
import torch | |||||
from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver | |||||
import paddle | |||||
from paddle.io import Dataset, DataLoader | |||||
class Net(paddle.nn.Layer): | |||||
def __init__(self): | |||||
super(Net, self).__init__() | |||||
self.fc1 = paddle.nn.Linear(784, 64) | |||||
self.fc2 = paddle.nn.Linear(64, 32) | |||||
self.fc3 = paddle.nn.Linear(32, 10) | |||||
self.fc4 = paddle.nn.Linear(10, 10) | |||||
def forward(self, x): | |||||
x = self.fc1(x) | |||||
x = self.fc2(x) | |||||
x = self.fc3(x) | |||||
x = self.fc4(x) | |||||
return x | |||||
class PaddleDataset(Dataset): | |||||
def __init__(self): | |||||
super(PaddleDataset, self).__init__() | |||||
self.items = [paddle.rand((3, 4)) for i in range(320)] | |||||
def __len__(self): | |||||
return len(self.items) | |||||
def __getitem__(self, idx): | |||||
return self.items[idx] | |||||
class TorchNet(torch.nn.Module): | |||||
def __init__(self): | |||||
super(TorchNet, self).__init__() | |||||
self.torch_fc1 = torch.nn.Linear(10, 10) | |||||
self.torch_softmax = torch.nn.Softmax(0) | |||||
self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3) | |||||
self.torch_tensor = torch.ones(3, 3) | |||||
self.torch_param = torch.nn.Parameter(torch.ones(4, 4)) | |||||
class TorchDataset(torch.utils.data.Dataset): | |||||
def __init__(self): | |||||
super(TorchDataset, self).__init__() | |||||
self.items = [torch.ones(3, 4) for i in range(320)] | |||||
def __len__(self): | |||||
return len(self.items) | |||||
def __getitem__(self, idx): | |||||
return self.items[idx] | |||||
class PaddleDriverTestCase(unittest.TestCase): | |||||
""" | |||||
PaddleDriver的测试类,由于类的特殊性仅测试部分函数,其它的由PaddleSingleDriver和PaddleFleetDriver完成测试 | |||||
""" | |||||
def setUp(self): | |||||
model = Net() | |||||
self.driver = PaddleDriver(model) | |||||
def test_check_single_optimizer_legacy(self): | |||||
""" | |||||
测试传入单个optimizer时的表现 | |||||
""" | |||||
optimizer = paddle.optimizer.Adam( | |||||
parameters=self.driver.model.parameters(), | |||||
learning_rate=0.01 | |||||
) | |||||
self.driver.set_optimizers(optimizer) | |||||
optimizer = torch.optim.Adam(TorchNet().parameters(), 0.01) | |||||
# 传入torch的optimizer时,应该报错ValueError | |||||
with self.assertRaises(ValueError) as cm: | |||||
self.driver.set_optimizers(optimizer) | |||||
def test_check_optimizers_legacy(self): | |||||
""" | |||||
测试传入optimizer list的表现 | |||||
""" | |||||
optimizers = [ | |||||
paddle.optimizer.Adam( | |||||
parameters=self.driver.model.parameters(), | |||||
learning_rate=0.01 | |||||
) for i in range(10) | |||||
] | |||||
self.driver.set_optimizers(optimizers) | |||||
optimizers += [ | |||||
torch.optim.Adam(TorchNet().parameters(), 0.01) | |||||
] | |||||
with self.assertRaises(ValueError) as cm: | |||||
self.driver.set_optimizers(optimizers) | |||||
def test_check_dataloader_legacy_in_train(self): | |||||
""" | |||||
测试is_train参数为True时,_check_dataloader_legality函数的表现 | |||||
""" | |||||
dataloader = paddle.io.DataLoader(PaddleDataset()) | |||||
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||||
# 创建torch的dataloader | |||||
dataloader = torch.utils.data.DataLoader( | |||||
TorchDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
with self.assertRaises(ValueError) as cm: | |||||
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||||
def test_check_dataloader_legacy_in_test(self): | |||||
""" | |||||
测试is_train参数为False时,_check_dataloader_legality函数的表现 | |||||
""" | |||||
# 此时传入的应该是dict | |||||
dataloader = {"train": paddle.io.DataLoader(PaddleDataset()), "test":paddle.io.DataLoader(PaddleDataset())} | |||||
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||||
# 传入的不是dict,应该报错 | |||||
dataloader = paddle.io.DataLoader(PaddleDataset()) | |||||
with self.assertRaises(ValueError) as cm: | |||||
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||||
# 创建torch的dataloader | |||||
train_loader = torch.utils.data.DataLoader( | |||||
TorchDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
test_loader = torch.utils.data.DataLoader( | |||||
TorchDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
dataloader = {"train": train_loader, "test": test_loader} | |||||
with self.assertRaises(ValueError) as cm: | |||||
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||||
def test_tensor_to_numeric(self): | |||||
""" | |||||
测试tensor_to_numeric函数 | |||||
""" | |||||
# 单个张量 | |||||
tensor = paddle.to_tensor(3) | |||||
res = PaddleDriver.tensor_to_numeric(tensor) | |||||
self.assertEqual(res, 3) | |||||
tensor = paddle.rand((3, 4)) | |||||
res = PaddleDriver.tensor_to_numeric(tensor) | |||||
self.assertListEqual(res, tensor.tolist()) | |||||
# 张量list | |||||
tensor_list = [paddle.rand((6, 4, 2)) for i in range(10)] | |||||
res = PaddleDriver.tensor_to_numeric(tensor_list) | |||||
self.assertTrue(res, list) | |||||
tensor_list = [t.tolist() for t in tensor_list] | |||||
self.assertListEqual(res, tensor_list) | |||||
# 张量tuple | |||||
tensor_tuple = tuple([paddle.rand((6, 4, 2)) for i in range(10)]) | |||||
res = PaddleDriver.tensor_to_numeric(tensor_tuple) | |||||
self.assertTrue(res, tuple) | |||||
tensor_tuple = tuple([t.tolist() for t in tensor_tuple]) | |||||
self.assertTupleEqual(res, tensor_tuple) | |||||
# 张量dict | |||||
tensor_dict = { | |||||
"tensor": paddle.rand((3, 4)), | |||||
"list": [paddle.rand((6, 4, 2)) for i in range(10)], | |||||
"dict":{ | |||||
"list": [paddle.rand((6, 4, 2)) for i in range(10)], | |||||
"tensor": paddle.rand((3, 4)) | |||||
}, | |||||
"int": 2, | |||||
"string": "test string" | |||||
} | |||||
res = PaddleDriver.tensor_to_numeric(tensor_dict) | |||||
self.assertIsInstance(res, dict) | |||||
self.assertListEqual(res["tensor"], tensor_dict["tensor"].tolist()) | |||||
self.assertIsInstance(res["list"], list) | |||||
for r, d in zip(res["list"], tensor_dict["list"]): | |||||
self.assertListEqual(r, d.tolist()) | |||||
self.assertIsInstance(res["int"], int) | |||||
self.assertIsInstance(res["string"], str) | |||||
self.assertIsInstance(res["dict"], dict) | |||||
self.assertIsInstance(res["dict"]["list"], list) | |||||
for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]): | |||||
self.assertListEqual(r, d.tolist()) | |||||
self.assertListEqual(res["dict"]["tensor"], tensor_dict["dict"]["tensor"].tolist()) | |||||
def test_set_model_mode(self): | |||||
""" | |||||
测试set_model_mode函数 | |||||
""" | |||||
self.driver.set_model_mode("train") | |||||
self.assertTrue(self.driver.model.training) | |||||
self.driver.set_model_mode("eval") | |||||
self.assertFalse(self.driver.model.training) | |||||
# 应该报错 | |||||
with self.assertRaises(AssertionError) as cm: | |||||
self.driver.set_model_mode("test") | |||||
def test_move_model_to_device_cpu(self): | |||||
""" | |||||
测试move_model_to_device函数 | |||||
""" | |||||
PaddleDriver.move_model_to_device(self.driver.model, "cpu") | |||||
self.assertTrue(self.driver.model.fc1.weight.place.is_cpu_place()) | |||||
def test_move_model_to_device_gpu(self): | |||||
""" | |||||
测试move_model_to_device函数 | |||||
""" | |||||
PaddleDriver.move_model_to_device(self.driver.model, "gpu:0") | |||||
self.assertTrue(self.driver.model.fc1.weight.place.is_gpu_place()) | |||||
self.assertEqual(self.driver.model.fc1.weight.place.gpu_device_id(), 0) | |||||
def test_worker_init_function(self): | |||||
""" | |||||
测试worker_init_function | |||||
""" | |||||
# 先确保不影响运行 | |||||
# TODO:正确性 | |||||
PaddleDriver.worker_init_function(0) | |||||
def test_set_deterministic_dataloader(self): | |||||
""" | |||||
测试set_deterministic_dataloader | |||||
""" | |||||
# 先确保不影响运行 | |||||
# TODO:正确性 | |||||
dataloader = DataLoader(PaddleDataset()) | |||||
self.driver.set_deterministic_dataloader(dataloader) | |||||
def test_set_sampler_epoch(self): | |||||
""" | |||||
测试set_sampler_epoch | |||||
""" | |||||
# 先确保不影响运行 | |||||
# TODO:正确性 | |||||
dataloader = DataLoader(PaddleDataset()) | |||||
self.driver.set_sampler_epoch(dataloader, 0) | |||||
def test_get_dataloader_args(self): | |||||
""" | |||||
测试get_dataloader_args | |||||
""" | |||||
# 先确保不影响运行 | |||||
# TODO:正确性 | |||||
dataloader = DataLoader(PaddleDataset()) | |||||
res = PaddleDriver.get_dataloader_args(dataloader) |
@@ -1,19 +1,19 @@ | |||||
import os | |||||
os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
import pytest | import pytest | ||||
from pathlib import Path | |||||
from fastNLP.envs.set_backend import set_env | |||||
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle | |||||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | |||||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from fastNLP.core import synchronize_safe_rm | |||||
set_env_on_import_paddle() | |||||
set_env("paddle") | |||||
import paddle | import paddle | ||||
from paddle.io import DataLoader, BatchSampler | from paddle.io import DataLoader, BatchSampler | ||||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | |||||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler | |||||
from fastNLP.core.samplers import RandomBatchSampler | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | |||||
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST, PaddleRandomDataset | |||||
from fastNLP.core import synchronize_safe_rm | |||||
import torch | |||||
############################################################################ | ############################################################################ | ||||
@@ -26,38 +26,116 @@ def generate_random_driver(features, labels): | |||||
""" | """ | ||||
生成driver | 生成driver | ||||
""" | """ | ||||
model = PaddleNormalModel_Classification(labels, features) | |||||
model = PaddleNormalModel_Classification_1(labels, features) | |||||
opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01) | opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01) | ||||
driver = PaddleSingleDriver(model) | |||||
driver = PaddleSingleDriver(model, device="cpu") | |||||
driver.set_optimizers(opt) | driver.set_optimizers(opt) | ||||
driver.setup() | |||||
return driver | return driver | ||||
@pytest.fixture | @pytest.fixture | ||||
def prepare_test_save_load(): | def prepare_test_save_load(): | ||||
dataset = PaddleRandomDataset(num_of_data=320, features=64, labels=8) | |||||
dataset = PaddleRandomMaxDataset(320, 10) | |||||
dataloader = DataLoader(dataset, batch_size=32) | dataloader = DataLoader(dataset, batch_size=32) | ||||
driver1, driver2 = generate_random_driver(64, 8), generate_random_driver(64, 8) | |||||
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | |||||
return driver1, driver2, dataloader | return driver1, driver2, dataloader | ||||
def test_save_and_load(prepare_test_save_load): | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
def test_save_and_load_with_randombatchsampler(only_state_dict): | |||||
""" | """ | ||||
测试save和load函数 | |||||
TODO optimizer的state_dict为空,暂时不测试 | |||||
测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 | |||||
""" | """ | ||||
try: | try: | ||||
path = "model.pdparams" | |||||
driver1, driver2, dataloader = prepare_test_save_load | |||||
path = "model.ckp" | |||||
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | |||||
dataset = PaddleRandomMaxDataset(80, 10) | |||||
dataloader = DataLoader( | |||||
dataset=dataset, | |||||
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) | |||||
) | |||||
# TODO 断点重训完善后在这里迭代几次 | |||||
sampler_states = dataloader.batch_sampler.state_dict() | |||||
if only_state_dict: | |||||
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True) | |||||
else: | |||||
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
# 1. 检查 optimizer 的状态 | |||||
# TODO optimizer 的 state_dict 总是为空 | |||||
# 2. 检查 batch_sampler 是否被正确地加载和替换 | |||||
replaced_loader = states["dataloader"] | |||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] | |||||
assert replaced_loader.batch_sampler.data_idx == sampler_states["data_idx"] | |||||
# 3. 检查 model 的参数是否被正确加载 | |||||
for batch in dataloader: | |||||
res1 = driver1.validate_step(batch) | |||||
res2 = driver2.validate_step(batch) | |||||
driver1.save(path, {}) | |||||
driver2.load(path) | |||||
assert paddle.equal_all(res1["pred"], res2["pred"]) | |||||
# 4. 检查 batch_idx | |||||
# TODO | |||||
finally: | |||||
synchronize_safe_rm(path) | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
def test_save_and_load_with_randomsampler(only_state_dict): | |||||
""" | |||||
测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 | |||||
""" | |||||
try: | |||||
path = "model.ckp" | |||||
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | |||||
dataset = PaddleRandomMaxDataset(80, 10) | |||||
batch_sampler = BatchSampler(dataset=dataset, batch_size=2) | |||||
batch_sampler.sampler = RandomSampler(dataset, True) | |||||
dataloader = DataLoader( | |||||
dataset, | |||||
batch_sampler=batch_sampler | |||||
) | |||||
# TODO 断点重训完善后在这里迭代几次 | |||||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | |||||
if only_state_dict: | |||||
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True) | |||||
else: | |||||
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
# 1. 检查 optimizer 的状态 | |||||
# TODO optimizer 的 state_dict 总是为空 | |||||
# 2. 检查 sampler 是否被正确地加载和替换 | |||||
replaced_loader = states["dataloader"] | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] | |||||
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] | |||||
assert replaced_loader.batch_sampler.sampler.num_consumed_samples == sampler_states["num_consumed_samples"] | |||||
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | |||||
# 3. 检查 model 的参数是否被正确加载 | |||||
for batch in dataloader: | for batch in dataloader: | ||||
res1 = driver1.validate_step(batch) | res1 = driver1.validate_step(batch) | ||||
res2 = driver2.validate_step(batch) | res2 = driver2.validate_step(batch) | ||||
assert paddle.equal_all(res1["pred"], res2["pred"]) | assert paddle.equal_all(res1["pred"], res2["pred"]) | ||||
# 4. 检查 batch_idx | |||||
# TODO | |||||
finally: | finally: | ||||
synchronize_safe_rm(path) | synchronize_safe_rm(path) | ||||
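Both tests follow the same checkpoint round-trip; collected in one place (argument order as used above, file name arbitrary):

    from pathlib import Path

    path = Path("model.ckp")
    driver1.save(path, {}, dataloader, only_state_dict=True, should_save_model=True)
    states = driver2.load(path, dataloader, only_state_dict=True, should_load_model=True)
    replaced_loader = states["dataloader"]  # carries the restored sampler state for resumption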
@@ -67,13 +145,14 @@ def test_save_and_load_state_dict(prepare_test_save_load): | |||||
TODO optimizer的state_dict为空,暂时不测试 | TODO optimizer的state_dict为空,暂时不测试 | ||||
""" | """ | ||||
try: | try: | ||||
path = "model.pdparams" | |||||
path = "dict" | |||||
driver1, driver2, dataloader = prepare_test_save_load | driver1, driver2, dataloader = prepare_test_save_load | ||||
driver1.save_model(path) | driver1.save_model(path) | ||||
driver2.model.load_dict(driver2.load_model(path)) | |||||
driver2.load_model(path) | |||||
for batch in dataloader: | for batch in dataloader: | ||||
batch = driver1.move_data_to_device(batch) | |||||
res1 = driver1.validate_step(batch) | res1 = driver1.validate_step(batch) | ||||
res2 = driver2.validate_step(batch) | res2 = driver2.validate_step(batch) | ||||
@@ -87,19 +166,22 @@ def test_save_and_load_whole_model(prepare_test_save_load): | |||||
TODO optimizer的state_dict为空,暂时不测试 | TODO optimizer的state_dict为空,暂时不测试 | ||||
""" | """ | ||||
try: | try: | ||||
path = "model.pdparams" | |||||
path = "model" | |||||
driver1, driver2, dataloader = prepare_test_save_load | driver1, driver2, dataloader = prepare_test_save_load | ||||
driver1.save_model(path, only_state_dict=False, input_spec=[next(iter(dataloader))["x"]]) | |||||
driver2.model = driver2.load_model(path, load_dict=False) | |||||
driver1.save_model(path, only_state_dict=False, input_spec=[paddle.ones((32, 10))]) | |||||
driver2.load_model(path, only_state_dict=False) | |||||
for batch in dataloader: | for batch in dataloader: | ||||
batch = driver1.move_data_to_device(batch) | |||||
res1 = driver1.validate_step(batch) | res1 = driver1.validate_step(batch) | ||||
res2 = driver2.validate_step(batch) | res2 = driver2.validate_step(batch) | ||||
assert paddle.equal_all(res1["pred"], res2["pred"]) | assert paddle.equal_all(res1["pred"], res2["pred"]) | ||||
finally: | finally: | ||||
synchronize_safe_rm(path) | |||||
synchronize_safe_rm(path + ".pdiparams") | |||||
synchronize_safe_rm(path + ".pdiparams.info") | |||||
synchronize_safe_rm(path + ".pdmodel") | |||||
class TestSingleDeviceFunction: | class TestSingleDeviceFunction: | ||||
@@ -109,8 +191,8 @@ class TestSingleDeviceFunction: | |||||
@classmethod | @classmethod | ||||
def setup_class(cls): | def setup_class(cls): | ||||
model = PaddleNormalModel_Classification(10, 784) | |||||
cls.driver = PaddleSingleDriver(model) | |||||
model = PaddleNormalModel_Classification_1(10, 784) | |||||
cls.driver = PaddleSingleDriver(model, device="cpu") | |||||
def test_unwrap_model(self): | def test_unwrap_model(self): | ||||
""" | """ | ||||
@@ -125,22 +207,6 @@ class TestSingleDeviceFunction: | |||||
self.driver.check_evaluator_mode("validate") | self.driver.check_evaluator_mode("validate") | ||||
self.driver.check_evaluator_mode("test") | self.driver.check_evaluator_mode("test") | ||||
def test_get_model_device_cpu(self): | |||||
""" | |||||
测试get_model_device | |||||
""" | |||||
self.driver = PaddleSingleDriver(PaddleNormalModel_Classification(10, 784), "cpu") | |||||
device = self.driver.get_model_device() | |||||
assert device == "cpu", device | |||||
def test_get_model_device_gpu(self): | |||||
""" | |||||
测试get_model_device | |||||
""" | |||||
self.driver = PaddleSingleDriver(PaddleNormalModel_Classification(10, 784), "gpu:0") | |||||
device = self.driver.get_model_device() | |||||
assert device == "gpu:0", device | |||||
def test_is_distributed(self): | def test_is_distributed(self): | ||||
assert self.driver.is_distributed() == False | assert self.driver.is_distributed() == False | ||||
@@ -151,18 +217,420 @@ class TestSingleDeviceFunction: | |||||
""" | """ | ||||
self.driver.move_data_to_device(paddle.rand((32, 64))) | self.driver.move_data_to_device(paddle.rand((32, 64))) | ||||
@pytest.mark.parametrize( | |||||
"dist_sampler", | |||||
["dist", RandomBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"reproducible", | |||||
[True, False] | |||||
) | |||||
def test_repalce_sampler(self, dist_sampler, reproducible): | |||||
class TestSetDistReproDataloader: |||||
""" | |||||
专门测试 set_dist_repro_dataloader 函数的类 | |||||
""" | |||||
def setup_method(self): | |||||
self.dataset = PaddleNormalDataset(20) | |||||
self.dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) | |||||
model = PaddleNormalModel_Classification_1(10, 32) | |||||
self.driver = PaddleSingleDriver(model, device="cpu") | |||||
def test_set_dist_repro_dataloader_with_reproducible_false(self): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现 | |||||
当dist为字符串时,此时应该返回原来的 dataloader | |||||
""" | """ | ||||
测试replace_sampler函数 | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=False) | |||||
assert replaced_loader is self.dataloader | |||||
def test_set_dist_repro_dataloader_with_reproducible_true(self): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 | |||||
当dist为字符串时,此时应该返回新的 dataloader,且 batch_sampler 为 RandomBatchSampler | |||||
""" | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=True) | |||||
assert not (replaced_loader is self.dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | |||||
assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size | |||||
assert replaced_loader.drop_last == self.dataloader.drop_last | |||||
# self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) | |||||
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self): | |||||
""" | """ | ||||
dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True) | |||||
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler | |||||
应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler | |||||
""" | |||||
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False) | |||||
assert not (replaced_loader is self.dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert replaced_loader.batch_sampler is dist | |||||
# self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) | |||||
def test_set_dist_repro_dataloader_with_dist_sampler(self): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现 | |||||
应该返回新的 dataloader,并将 batch_sampler.sampler 替换为 dist 对应的 Sampler | |||||
""" | |||||
dist = RandomSampler(self.dataset, shuffle=True) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False) | |||||
assert not (replaced_loader is self.dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert not (replaced_loader.batch_sampler is self.dataloader.batch_sampler) | |||||
assert replaced_loader.batch_sampler.sampler is dist | |||||
assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size | |||||
# self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) | |||||
def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 | |||||
应该返回新的 dataloader,且其余各项设置和原来相同 | |||||
""" | |||||
dataloader = DataLoader( | |||||
dataset=self.dataset, | |||||
batch_sampler=RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False) | |||||
) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible) | |||||
# self.check_set_dist_repro_dataloader(dataloader, replaced_loader) | |||||
def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 | |||||
应该返回新的 dataloader,且其余各项设置和原来相同 | |||||
""" | |||||
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) | |||||
batch_sampler.sampler = RandomSampler(self.dataset, True) | |||||
dataloader = DataLoader( | |||||
self.dataset, | |||||
batch_sampler=batch_sampler | |||||
) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 2 | |||||
# self.check_set_dist_repro_dataloader(dataloader, replaced_loader) | |||||
def check_set_dist_repro_dataloader(self, dataloader, replaced_loader): | |||||
""" | |||||
测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确 | |||||
""" | |||||
# 迭代两个 batch | |||||
# 这里会发生 BatchSampler 里 yield 了多次但 dataloader 只取出一次的情况。 | |||||
already_seen_idx = set() | |||||
for idx, batch in enumerate(replaced_loader): |||||
already_seen_idx.update(batch) | |||||
if idx >= 1: | |||||
break | |||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | |||||
sampler_states = replaced_loader.batch_sampler.state_dict() | |||||
else: | |||||
sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | |||||
# 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range | |||||
left_idxes = set() | |||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | |||||
replaced_loader.batch_sampler.load_state_dict(sampler_states) | |||||
else: | |||||
replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states) | |||||
for idx, batch in enumerate(replaced_loader): | |||||
left_idxes.update(batch) | |||||
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) | |||||
assert len(left_idxes | already_seen_idx) == len(self.dataset) | |||||
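The helper above encodes the resumption invariant the reproducible samplers are meant to provide: indices consumed before a state_dict snapshot and indices produced after restoring it partition the dataset exactly. The same contract against RandomSampler directly (assuming, as the helper does, that consumption is tracked while iterating):

    sampler = RandomSampler(list(range(20)), shuffle=True)
    it = iter(sampler)
    seen = {next(it) for _ in range(10)}
    states = sampler.state_dict()  # snapshot mid-epoch

    restored = RandomSampler(list(range(20)), shuffle=True)
    restored.load_state_dict(states)
    left = set(restored)
    assert seen | left == set(range(20)) and not (seen & left)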
class TestPaddleDriverFunctions: | |||||
""" | |||||
使用 PaddleSingleDriver 测试基类的函数 | |||||
""" | |||||
@classmethod | |||||
def setup_class(cls): |||||
model = PaddleNormalModel_Classification_1(10, 32) |||||
cls.driver = PaddleSingleDriver(model, device="cpu") |||||
def test_check_single_optimizer_legality(self): | |||||
""" | |||||
测试传入单个optimizer时的表现 | |||||
""" | |||||
optimizer = paddle.optimizer.Adam( | |||||
parameters=self.driver.model.parameters(), | |||||
learning_rate=0.01 | |||||
) | |||||
self.driver.set_optimizers(optimizer) | |||||
optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||||
# 传入torch的optimizer时,应该报错ValueError | |||||
with pytest.raises(ValueError): | |||||
self.driver.set_optimizers(optimizer) | |||||
def test_check_optimizers_legality(self): | |||||
""" | |||||
测试传入optimizer list的表现 | |||||
""" | |||||
optimizers = [ | |||||
paddle.optimizer.Adam( | |||||
parameters=self.driver.model.parameters(), | |||||
learning_rate=0.01 | |||||
) for i in range(10) | |||||
] | |||||
self.driver.set_optimizers(optimizers) | |||||
optimizers += [ | |||||
torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||||
] | |||||
with pytest.raises(ValueError): | |||||
self.driver.set_optimizers(optimizers) | |||||
def test_check_dataloader_legality_in_train(self): | |||||
""" | |||||
测试is_train参数为True时,_check_dataloader_legality函数的表现 | |||||
""" | |||||
dataloader = paddle.io.DataLoader(PaddleNormalDataset()) | |||||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||||
# batch_size 和 batch_sampler 均为 None 的情形 | |||||
dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) | |||||
with pytest.raises(ValueError): | |||||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||||
# 创建torch的dataloader | |||||
dataloader = torch.utils.data.DataLoader( | |||||
TorchNormalDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
with pytest.raises(ValueError): | |||||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||||
def test_check_dataloader_legality_in_test(self): | |||||
""" | |||||
测试is_train参数为False时,_check_dataloader_legality函数的表现 | |||||
""" | |||||
# 此时传入的应该是dict | |||||
dataloader = { | |||||
"train": paddle.io.DataLoader(PaddleNormalDataset()), | |||||
"test":paddle.io.DataLoader(PaddleNormalDataset()) | |||||
} | |||||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||||
# batch_size 和 batch_sampler 均为 None 的情形 | |||||
dataloader = { | |||||
"train": paddle.io.DataLoader(PaddleNormalDataset()), | |||||
"test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) | |||||
} | |||||
with pytest.raises(ValueError): | |||||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||||
# 传入的不是dict,应该报错 | |||||
dataloader = paddle.io.DataLoader(PaddleNormalDataset()) | |||||
with pytest.raises(ValueError): | |||||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||||
# 创建torch的dataloader | |||||
train_loader = torch.utils.data.DataLoader( | |||||
TorchNormalDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
test_loader = torch.utils.data.DataLoader( | |||||
TorchNormalDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
dataloader = {"train": train_loader, "test": test_loader} | |||||
with pytest.raises(ValueError): | |||||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||||
def test_tensor_to_numeric(self): | |||||
""" | |||||
测试tensor_to_numeric函数 | |||||
""" | |||||
# 单个张量 | |||||
tensor = paddle.to_tensor(3) | |||||
res = PaddleSingleDriver.tensor_to_numeric(tensor) | |||||
assert res == 3 | |||||
tensor = paddle.rand((3, 4)) | |||||
res = PaddleSingleDriver.tensor_to_numeric(tensor) | |||||
assert res == tensor.tolist() | |||||
# 张量list | |||||
tensor_list = [paddle.rand((6, 4, 2)) for i in range(10)] | |||||
res = PaddleSingleDriver.tensor_to_numeric(tensor_list) | |||||
assert isinstance(res, list) | |||||
tensor_list = [t.tolist() for t in tensor_list] | |||||
assert res == tensor_list | |||||
# 张量tuple | |||||
tensor_tuple = tuple([paddle.rand((6, 4, 2)) for i in range(10)]) | |||||
res = PaddleSingleDriver.tensor_to_numeric(tensor_tuple) | |||||
assert isinstance(res, tuple) | |||||
tensor_tuple = tuple([t.tolist() for t in tensor_tuple]) | |||||
assert res == tensor_tuple | |||||
# 张量dict | |||||
tensor_dict = { | |||||
"tensor": paddle.rand((3, 4)), | |||||
"list": [paddle.rand((6, 4, 2)) for i in range(10)], | |||||
"dict":{ | |||||
"list": [paddle.rand((6, 4, 2)) for i in range(10)], | |||||
"tensor": paddle.rand((3, 4)) | |||||
}, | |||||
"int": 2, | |||||
"string": "test string" | |||||
} | |||||
res = PaddleSingleDriver.tensor_to_numeric(tensor_dict) | |||||
assert isinstance(res, dict) | |||||
assert res["tensor"] == tensor_dict["tensor"].tolist() | |||||
assert isinstance(res["list"], list) | |||||
for r, d in zip(res["list"], tensor_dict["list"]): | |||||
assert r == d.tolist() | |||||
assert isinstance(res["int"], int) | |||||
assert isinstance(res["string"], str) | |||||
assert isinstance(res["dict"], dict) | |||||
assert isinstance(res["dict"]["list"], list) | |||||
for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]): | |||||
assert r == d.tolist() | |||||
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() | |||||
def test_set_model_mode(self): | |||||
""" | |||||
测试set_model_mode函数 | |||||
""" | |||||
self.driver.set_model_mode("train") | |||||
assert self.driver.model.training | |||||
self.driver.set_model_mode("eval") | |||||
assert not self.driver.model.training | |||||
# 应该报错 | |||||
with pytest.raises(AssertionError): | |||||
self.driver.set_model_mode("test") | |||||
def test_move_model_to_device_cpu(self): | |||||
""" | |||||
测试move_model_to_device函数 | |||||
""" | |||||
PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu") | |||||
assert self.driver.model.linear1.weight.place.is_cpu_place() | |||||
def test_move_model_to_device_gpu(self): | |||||
""" | |||||
测试move_model_to_device函数 | |||||
""" | |||||
PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu") | |||||
assert self.driver.model.linear1.weight.place.is_gpu_place() | |||||
assert self.driver.model.linear1.weight.place.gpu_device_id() == 0 | |||||
def test_worker_init_function(self): | |||||
""" | |||||
测试worker_init_function | |||||
""" | |||||
# 先确保不影响运行 | |||||
# TODO:正确性 | |||||
PaddleSingleDriver.worker_init_function(0) | |||||
def test_set_deterministic_dataloader(self): | |||||
""" | |||||
测试set_deterministic_dataloader | |||||
""" | |||||
# 先确保不影响运行 | |||||
# TODO:正确性 | |||||
dataloader = DataLoader(PaddleNormalDataset()) | |||||
self.driver.set_deterministic_dataloader(dataloader) | |||||
def test_set_sampler_epoch(self): | |||||
""" | |||||
测试set_sampler_epoch | |||||
""" | |||||
# 先确保不影响运行 | |||||
# TODO:正确性 | |||||
dataloader = DataLoader(PaddleNormalDataset()) | |||||
self.driver.set_sampler_epoch(dataloader, 0) | |||||
@pytest.mark.parametrize("batch_size", [16]) | |||||
@pytest.mark.parametrize("shuffle", [True, False]) | |||||
@pytest.mark.parametrize("drop_last", [True, False]) | |||||
def test_get_dataloader_args(self, batch_size, shuffle, drop_last): | |||||
""" | |||||
测试正常情况下 get_dataloader_args 的表现 | |||||
""" | |||||
dataloader = DataLoader( | |||||
PaddleNormalDataset(), | |||||
batch_size=batch_size, | |||||
shuffle=shuffle, | |||||
drop_last=drop_last, | |||||
) | |||||
res = PaddleSingleDriver.get_dataloader_args(dataloader) | |||||
assert isinstance(res.dataset, PaddleNormalDataset) | |||||
assert isinstance(res.batch_sampler, BatchSampler) | |||||
if shuffle: | |||||
assert isinstance(res.sampler, paddle.io.RandomSampler) | |||||
else: | |||||
assert isinstance(res.sampler, paddle.io.SequenceSampler) | |||||
assert res.shuffle == shuffle | |||||
assert res.batch_size == batch_size | |||||
assert res.drop_last == drop_last | |||||
@pytest.mark.parametrize("batch_size", [16]) | |||||
@pytest.mark.parametrize("shuffle", [True, False]) | |||||
@pytest.mark.parametrize("drop_last", [True, False]) | |||||
def test_get_dataloader_args_with_randombatchsampler(self, batch_size, shuffle, drop_last): | |||||
""" | |||||
测试替换了 batch_sampler 后 get_dataloader_args 的表现 | |||||
""" | |||||
dataset = PaddleNormalDataset() | |||||
dataloader = DataLoader( | |||||
dataset, | |||||
batch_sampler=RandomBatchSampler( | |||||
BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle), | |||||
batch_size, | |||||
drop_last, | |||||
) | |||||
) | |||||
res = PaddleSingleDriver.get_dataloader_args(dataloader) | |||||
assert isinstance(res.dataset, PaddleNormalDataset) | |||||
assert isinstance(res.batch_sampler, RandomBatchSampler) | |||||
if shuffle: | |||||
assert isinstance(res.sampler, paddle.io.RandomSampler) | |||||
else: | |||||
assert isinstance(res.sampler, paddle.io.SequenceSampler) | |||||
assert res.shuffle == shuffle | |||||
assert res.batch_size == batch_size | |||||
assert res.drop_last == drop_last | |||||
@pytest.mark.parametrize("batch_size", [16]) | |||||
@pytest.mark.parametrize("shuffle", [True, False]) | |||||
@pytest.mark.parametrize("drop_last", [True, False]) | |||||
def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last): | |||||
""" | |||||
测试替换了 sampler 后 get_dataloader_args 的表现 | |||||
""" | |||||
dataset = PaddleNormalDataset() | |||||
batch_sampler = BatchSampler(dataset, batch_size=batch_size, drop_last=drop_last) | |||||
batch_sampler.sampler = RandomSampler(dataset, shuffle) | |||||
dataloader = DataLoader( | |||||
dataset, | |||||
batch_sampler=batch_sampler, | |||||
) | |||||
res = PaddleSingleDriver.get_dataloader_args(dataloader) | |||||
assert isinstance(res.dataset, PaddleNormalDataset) | |||||
assert isinstance(res.batch_sampler, BatchSampler) | |||||
assert isinstance(res.sampler, RandomSampler) | |||||
assert res.shuffle == shuffle | |||||
assert res.batch_size == batch_size | |||||
assert res.drop_last == drop_last |
@@ -1,4 +1,56 @@ | |||||
import unittest | |||||
import os | |||||
import pytest | |||||
os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
from fastNLP.core.drivers.paddle_driver.utils import ( | |||||
get_device_from_visible, | |||||
replace_batch_sampler, | |||||
replace_sampler, | |||||
) | |||||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler | |||||
import paddle | import paddle | ||||
from paddle.io import Dataset, DataLoader, DistributedBatchSampler | |||||
from paddle.io import DataLoader, BatchSampler | |||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset | |||||
@pytest.mark.parametrize( | |||||
("user_visible_devices, cuda_visible_devices, device, output_type, correct"), | |||||
( | |||||
("0,1,2,3,4,5,6,7", "0", "cpu", str, "cpu"), | |||||
("0,1,2,3,4,5,6,7", "0", "cpu", int, "cpu"), | |||||
("0,1,2,3,4,5,6,7", "3,4,5", "gpu:4", int, 1), | |||||
("0,1,2,3,4,5,6,7", "3,4,5", "gpu:5", str, "gpu:2"), | |||||
("3,4,5,6", "3,5", 0, int, 0), | |||||
("3,6,7,8", "6,7,8", "gpu:2", str, "gpu:1"), | |||||
) | |||||
) | |||||
def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices, device, output_type, correct): | |||||
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices | |||||
os.environ["USER_CUDA_VISIBLE_DEVICES"] = user_visible_devices | |||||
res = get_device_from_visible(device, output_type) | |||||
assert res == correct | |||||
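The cases encode a two-hop remapping: an integer device (or gpu:N) is first resolved through USER_CUDA_VISIBLE_DEVICES to a physical card id, and that id is then replaced by its position within CUDA_VISIBLE_DEVICES, while "cpu" passes through untouched whatever output_type says. The second hop, reconstructed from the expected values rather than taken from the library's code:

    def expected_index(physical_id: int, cuda_visible_devices: str) -> int:
        visible = [int(x) for x in cuda_visible_devices.split(",")]
        return visible.index(physical_id)

    assert expected_index(4, "3,4,5") == 1  # matches the ("gpu:4", int) -> 1 case
    assert expected_index(7, "6,7,8") == 1  # matches the ("gpu:2", str) -> "gpu:1" case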
def test_replace_batch_sampler(): | |||||
dataset = PaddleNormalDataset(10) | |||||
dataloader = DataLoader(dataset, batch_size=32) | |||||
batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) | |||||
replaced_loader = replace_batch_sampler(dataloader, batch_sampler) | |||||
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.dataset, PaddleNormalDataset)
assert len(replaced_loader.dataset) == len(dataset)
assert replaced_loader.batch_sampler.batch_size == 16
def test_replace_sampler():
dataset = PaddleNormalDataset(10)
dataloader = DataLoader(dataset, batch_size=32)
sampler = RandomSampler(dataset)
replaced_loader = replace_sampler(dataloader, sampler)
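# replace_sampler likewise builds a fresh DataLoader, wrapping the new sampler in a BatchSampler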
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
@@ -6,13 +6,16 @@ import logging
import re
from fastNLP.envs.env import FASTNLP_LAUNCH_TIME
from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import synchronize_safe_rm
from fastNLP.core.log.logger import logger
from tests.helpers.utils import magic_argv_env_context, recover_logger
# Test TorchDDPDriver;
@magic_argv_env_context
def test_add_file_ddp_1():
@recover_logger
def test_add_file_ddp_1_torch():
"""
Test when path points to a file whose parent directory exists;
@@ -56,11 +59,11 @@ def test_add_file_ddp_1():
synchronize_safe_rm(filepath)
dist.barrier()
dist.destroy_process_group()
logger.removeHandler(handler)
@magic_argv_env_context
def test_add_file_ddp_2():
@recover_logger
def test_add_file_ddp_2_torch():
"""
Test when path points to a file whose parent directory does not exist;
"""
@@ -103,14 +106,14 @@ def test_add_file_ddp_2():
assert len(pattern.findall(line)) == 1
finally:
synchronize_safe_rm(path)
logger.removeHandler(handler)
dist.barrier()
dist.destroy_process_group()
@magic_argv_env_context
def test_add_file_ddp_3():
@recover_logger
def test_add_file_ddp_3_torch():
"""
path = None;
@@ -155,10 +158,10 @@ def test_add_file_ddp_3():
synchronize_safe_rm(file)
dist.barrier()
dist.destroy_process_group()
logger.removeHandler(handler)
@magic_argv_env_context
def test_add_file_ddp_4():
@recover_logger
def test_add_file_ddp_4_torch():
"""
Test when path is a directory;
"""
@@ -200,7 +203,6 @@ def test_add_file_ddp_4():
assert len(pattern.findall(line)) == 1
finally:
synchronize_safe_rm(path)
logger.removeHandler(handler)
dist.barrier()
dist.destroy_process_group()
@@ -209,12 +211,11 @@ def test_add_file_ddp_4():
class TestLogger:
msg = 'some test log msg'
@recover_logger
def test_add_file_1(self):
"""
Test when path points to a file whose parent directory exists;
"""
from fastNLP.core.log.logger import logger
path = Path(tempfile.mkdtemp())
try:
filepath = path.joinpath('log.txt')
@@ -225,14 +226,12 @@ class TestLogger:
assert self.msg in line
finally:
synchronize_safe_rm(path)
logger.removeHandler(handler)
@recover_logger
def test_add_file_2(self):
"""
Test when path points to a file whose parent directory does not exist;
"""
from fastNLP.core.log.logger import logger
origin_path = Path(tempfile.mkdtemp())
try:
@@ -245,14 +244,12 @@ class TestLogger:
assert self.msg in line
finally:
synchronize_safe_rm(origin_path)
logger.removeHandler(handler)
@recover_logger
def test_add_file_3(self):
"""
Test when path is None;
"""
from fastNLP.core.log.logger import logger
handler = logger.add_file()
logger.info(self.msg)
@@ -264,14 +261,12 @@ class TestLogger:
line = ''.join([l for l in f])
assert self.msg in line
file.unlink()
logger.removeHandler(handler)
@recover_logger
def test_add_file_4(self):
"""
Test when path is a directory;
"""
from fastNLP.core.log.logger import logger
path = Path(tempfile.mkdtemp())
try:
handler = logger.add_file(path)
@@ -285,16 +280,21 @@ class TestLogger:
assert self.msg in line
finally:
synchronize_safe_rm(path)
logger.removeHandler(handler)
@recover_logger
def test_stdout(self, capsys):
from fastNLP.core.log.logger import logger
handler = logger.set_stdout(stdout="raw")
logger.info(self.msg)
logger.debug('aabbc')
captured = capsys.readouterr()
assert "some test log msg\n" == captured.out
logger.removeHandler(handler)
@recover_logger
def test_warning_once(self, capsys):
logger.warning_once('#')
logger.warning_once('#')
logger.warning_once('@')
captured = capsys.readouterr()
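# warning_once deduplicates by message content, so each distinct message should appear exactly once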
assert captured.out.count('#') == 1
assert captured.out.count('@') == 1
@@ -3,6 +3,7 @@ from array import array
import numpy as np
import pytest
from itertools import chain
from copy import deepcopy
from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
@@ -30,7 +31,7 @@ class TestReproducibleBatchSampler:
_get_re_batchsampler = dataloader.batch_sampler
assert isinstance(_get_re_batchsampler, RandomBatchSampler)
state = _get_re_batchsampler.state_dict()
assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size,
assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
"sampler_type": "RandomBatchSampler"}
# 2. Resume from the checkpoint and rebuild a dataloader;
@@ -413,26 +414,102 @@ class TestBucketedBatchSampler:
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
@pytest.mark.parametrize('num_replica', [2, 3])
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replica):
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replica=2):
@pytest.mark.parametrize('num_replicas', [2, 3])
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replicas):
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replicas=2):
dataset = DatasetWithVaryLength(num_of_data=num_samples)
batch_size = 6
if num_replica*batch_size > num_samples:
if num_replicas*batch_size > num_samples:
return
num_batch_per_bucket = 10
samplers = []
lengths = []
for i in range(num_replica):
for i in range(num_replicas):
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last)
sampler.set_distributed(num_replica, rank=i, pad=pad)
sampler.set_distributed(num_replicas, rank=i, pad=pad)
sampler.set_epoch(0)
samplers.append(sampler)
lengths.append(len(list(iter(sampler))))
assert len(set(lengths))==1
bucket_diff = batch_size * num_batch_per_bucket * num_replica
bucket_diff = batch_size * num_batch_per_bucket * num_replicas
for bs in zip(*samplers):
diff = max(chain(*bs)) - min(chain(*bs))
assert diff <= bucket_diff
@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
@pytest.mark.parametrize('num_replicas', [1, 2, 3])
def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas):
"""
Test that already consumed (forwarded) data can be recovered correctly. Because the
DataLoader prefetches, the sampler's own num_consumed_samples may run ahead of what
was actually used.
"""
batch_size = 6
num_batch_per_bucket = 10
dataset = DatasetWithVaryLength(num_of_data=num_samples)
samplers = []
for i in range(num_replicas):
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last)
sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad)
samplers.append(sampler)
count = 0
already_seen_sets = [set()]
already_seen_set = set()
for batches in zip(*samplers):
batch = chain(*batches)
already_seen_set.update(batch)
already_seen_sets.append(deepcopy(already_seen_set))
count += 1
if count > 3:
break
states = samplers[0].state_dict()
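# the test relies on num_consumed_samples_array snapshotting the consumed count after each
# batch: entry i restores the state right after the i-th batch, even if the sampler ran ahead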
for i in range(len(already_seen_sets)):
if states['num_consumed_samples_array'] is not None:
states['num_consumed_samples'] = states['num_consumed_samples_array'][i]
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1,
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
drop_last=drop_last)
sampler.load_state_dict(states)  # restore the recorded snapshot before replaying the remainder; mirrors the resume path below
sampler.set_epoch(0)
already_seen_set = deepcopy(already_seen_sets[i])
for batch in sampler:
already_seen_set.update(batch)
assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len(
dataset)
# Test saving again after restoring from a saved state
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1,
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
drop_last=drop_last)
sampler.set_epoch(0)
if states['num_consumed_samples_array'] is not None:
states['num_consumed_samples'] = states['num_consumed_samples_array'][2]
if len(already_seen_sets)<3:
return
sampler.load_state_dict(states)  # presumably intended: resume from the snapshot taken after the second batch
already_seen_set = already_seen_sets[2]
count = 0
for batch in sampler:
already_seen_set.update(batch)
count += 1
if count > 6:
break
states = sampler.state_dict()
if states['num_consumed_samples_array'] is not None:
states['num_consumed_samples'] = states['num_consumed_samples_array'][count]
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2,
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
drop_last=drop_last)
sampler.load_state_dict(states)
sampler.set_epoch(0)
for batch in sampler:
already_seen_set.update(batch)
assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset)
@@ -3,6 +3,7 @@ import pytest
from functools import partial
from itertools import chain
from copy import deepcopy
from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler
from tests.helpers.datasets.torch_data import TorchNormalDataset
@@ -180,6 +181,63 @@ class TestRandomSamplerYh:
assert seen <= 1 if pad else seen == 0
assert seen_in_other_rank<=1  # pad may introduce duplicates
@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
@pytest.mark.parametrize('num_replicas', [1, 2, 3])
def test_num_consumed_samples_array(self, shuffle, pad, num_samples, num_replicas):
# Test that state can still be restored when the sampler has produced more samples than were consumed
dataset = DatasetWithVaryLength(num_of_data=num_samples)
samplers = []
for i in range(num_replicas):
sampler = RandomSampler(dataset, shuffle=shuffle)
sampler.set_epoch(0)
sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad)
samplers.append(sampler)
count = 0
already_seen_sets = [set()]
already_seen_set = set()
for idxes in zip(*samplers):
already_seen_set.update(idxes)
already_seen_sets.append(deepcopy(already_seen_set))
count += 1
if count > 3:
break
states = samplers[0].state_dict()
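# as above, entry i of num_consumed_samples_array is assumed to be the consumed count right after the i-th batch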
for i in range(len(already_seen_sets)):
if states['num_consumed_samples_array'] is not None:
states['num_consumed_samples'] = states['num_consumed_samples_array'][i]
sampler = RandomSampler(dataset, shuffle=shuffle)
sampler.load_state_dict(states)  # restore the recorded snapshot before replaying the remainder; mirrors the resume path below
already_seen_set = deepcopy(already_seen_sets[i])
for batch in sampler:
already_seen_set.add(batch)
assert len(already_seen_set) == len(dataset)
# Test saving again after restoring from a saved state
sampler = RandomSampler(dataset, shuffle=shuffle)
sampler.set_epoch(0)
if states['num_consumed_samples_array'] is not None:
states['num_consumed_samples'] = states['num_consumed_samples_array'][2]
if len(already_seen_sets)<3:
return
sampler.load_state_dict(states)  # presumably intended: resume from the snapshot taken after the second batch
already_seen_set = already_seen_sets[2]
count = 0
for idx in sampler:
already_seen_set.add(idx)
count += 1
if count > 6:
break
states = sampler.state_dict()
if states['num_consumed_samples_array'] is not None:
states['num_consumed_samples'] = states['num_consumed_samples_array'][count]
sampler = RandomSampler(dataset, shuffle=shuffle)
sampler.load_state_dict(states)
sampler.set_epoch(0)
for idx in sampler:
already_seen_set.add(idx)
assert len(already_seen_set)==len(dataset)
class TestRandomSampler:
# Test the single-GPU case;
@@ -386,7 +444,7 @@ class TestSortedSampler:
assert indexes==list(range(num_of_data-1, -1, -1))
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_replicas', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
def test_multi(self, pad, num_replicas, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data)
@@ -540,7 +598,7 @@ class TestSequentialSampler:
assert indexes==list(range(num_of_data))
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_replicas', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
def test_multi(self, pad, num_replicas, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data)
@@ -25,7 +25,7 @@ class TestUnrepeatedSampler:
indexes = set(sampler)
assert indexes==set(range(num_of_data))
@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_replicas', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
@pytest.mark.parametrize('shuffle', [False, True])
def test_multi(self, num_replicas, num_of_data, shuffle):
@@ -50,7 +50,7 @@ class TestUnrepeatedSortedSampler:
indexes = list(sampler)
assert indexes==list(range(num_of_data-1, -1, -1))
@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_replicas', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
def test_multi(self, num_replicas, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data)
@@ -81,7 +81,7 @@ class TestUnrepeatedSequentialSampler:
indexes = list(sampler)
assert indexes==list(range(num_of_data))
@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_replicas', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
def test_multi(self, num_replicas, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data)
@@ -0,0 +1,187 @@
from functools import partial
import pytest
from fastNLP.core.utils.utils import auto_param_call, _check_valid_parameters_number, _get_fun_msg
from fastNLP.core.metrics import Metric
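# auto_param_call picks the arguments a callable needs, by name, out of one or more dicts;
# _check_valid_parameters_number checks that a user callback can accept the expected parameter list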
class TestAutoParamCall:
def test_basic(self):
def fn(x):
return x
x = {'x': 3, 'y': 4}
r = auto_param_call(fn, x)
assert r==3
xs = []
for i in range(10):
xs.append({f'x{i}': i})
def fn(x0, x1, x2, x3):
return x0 + x1 + x2 + x3
r = auto_param_call(fn, *xs)
assert r == 0 + 1 + 2 + 3
def fn(chongfu1, chongfu2, buChongFu):
pass
with pytest.raises(BaseException) as exc_info:
auto_param_call(fn, {'chongfu1': 3, "chongfu2":4, 'buChongFu':2},
{'chongfu1': 1, 'chongfu2':2, 'buChongFu':2})
assert 'The following key present in several inputs' in exc_info.value.args[0]
assert 'chongfu1' in exc_info.value.args[0] and 'chongfu2' in exc_info.value.args[0]
# No error if the duplicated key is never used
def fn(chongfu1, buChongFu):
pass
auto_param_call(fn, {'chongfu1': 1, "chongfu2":4, 'buChongFu':2},
{'chongfu1': 1, 'chongfu2':2, 'buChongFu':2})
# A custom signature_fn can be supplied
def fn1(**kwargs):
kwargs.pop('x')
kwargs.pop('y')
assert len(kwargs)==0
def fn(x, y):
pass
x = {'x': 3, 'y': 4}
r = auto_param_call(fn1, x, signature_fn=fn)
# An error is raised when required arguments are missing
def fn(meiti1, meiti2, tigong):
pass
with pytest.raises(BaseException) as exc_info:
auto_param_call(fn, {'tigong':1})
assert 'meiti1' in exc_info.value.args[0] and 'meiti2' in exc_info.value.args[0]
# Default values are used as fallbacks
def fn(x, y=100):
return x + y
r = auto_param_call(fn, {'x': 10, 'y': 20})
assert r==30
assert auto_param_call(fn, {'x': 10, 'z': 20})==110
# Test the use of mapping
def fn(x, y=100):
return x + y
r = auto_param_call(fn, {'x1': 10, 'y1': 20}, mapping={'x1': 'x', 'y1': 'y', 'meiyong': 'meiyong'})
assert r==30
# Test a function that needs no arguments
def fn():
return 1
assert 1 == auto_param_call(fn, {'x':1})
# Test that calling a bound method works
assert 2==auto_param_call(self.call_this, {'x':1 ,'y':1})
assert 2==auto_param_call(self.call_this, {'x':1,'y':1, 'z':1},mapping={'z': 'self'})
def test_msg(self):
with pytest.raises(BaseException) as exc_info:
auto_param_call(self.call_this, {'x':1})
assert 'TestAutoParamCall.call_this' in exc_info.value.args[0]
with pytest.raises(BaseException) as exc_info:
auto_param_call(call_this_for_auto_param_call, {'x':1})
assert __file__ in exc_info.value.args[0]
assert 'call_this_for_auto_param_call' in exc_info.value.args[0]
with pytest.raises(BaseException) as exc_info:
auto_param_call(self.call_this_two, {'x':1})
assert __file__ in exc_info.value.args[0]
with pytest.raises(BaseException) as exc_info:
auto_param_call(call_this_for_auto_param_call, {'x':1}, signature_fn=self.call_this)
assert 'TestAutoParamCall.call_this' in exc_info.value.args[0]  # should report the signature_fn's info
def call_this(self, x, y):
return x + y
def call_this_two(self, x, y, z=pytest, **kwargs):
return x + y
def test_metric_auto_param_call(self):
metric = AutoParamCallMetric()
with pytest.raises(BaseException):
auto_param_call(metric.update, {'y':1}, signature_fn=metric.update.__wrapped__)
class AutoParamCallMetric(Metric):
def update(self, x):
pass
def call_this_for_auto_param_call(x, y):
return x + y
class TestCheckNumberOfParameters:
def test_validate_every(self):
def validate_every(trainer):
pass
_check_valid_parameters_number(validate_every, expected_params=['trainer'])
# An extra parameter without a default raises
def validate_every(trainer, other):
pass
with pytest.raises(RuntimeError) as exc_info:
_check_valid_parameters_number(validate_every, expected_params=['trainer'])
assert "2 parameters" in exc_info.value.args[0]
print(exc_info.value.args[0])
# OK when the extra parameter has a default
def validate_every(trainer, other=1):
pass
_check_valid_parameters_number(validate_every, expected_params=['trainer'])
# Expecting more parameters than the function accepts raises
def validate_every(trainer):
pass
with pytest.raises(RuntimeError) as exc_info:
_check_valid_parameters_number(validate_every, expected_params=['trainer', 'other'])
assert "accepts 1 parameters" in exc_info.value.args[0]
print(exc_info.value.args[0])
# With functools.partial
def validate_every(trainer, other):
pass
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer'])
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other'])
with pytest.raises(RuntimeError) as exc_info:
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other', 'more'])
assert 'accepts 2 parameters' in exc_info.value.args[0]
print(exc_info.value.args[0])
# With *args or **kwargs, extra expected parameters do not raise
def validate_every(trainer, *args):
pass
_check_valid_parameters_number(validate_every, expected_params=['trainer', 'other', 'more'])
def validate_every(trainer, **kwargs):
pass
_check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more'])
# For bound methods, self is not counted
class InnerClass:
def demo(self, x):
pass
def no_param(self):
pass
def param_kwargs(self, **kwargs):
pass
inner = InnerClass()
with pytest.raises(RuntimeError) as exc_info:
_check_valid_parameters_number(inner.demo, expected_params=['trainer', 'other', 'more'])
assert 'accepts 1 parameters' in exc_info.value.args[0]
_check_valid_parameters_number(inner.demo, expected_params=['trainer'])
def test_get_fun_msg():
def demo(x):
pass
print(_get_fun_msg(demo))
@@ -101,12 +101,18 @@ class RecordTrainerEventTriggerCallback(Callback):
def on_after_backward(self, trainer):
print("on_after_backward")
def on_before_optimizer_step(self, trainer, optimizers):
print("on_before_optimizer_step")
def on_before_optimizers_step(self, trainer, optimizers):
print("on_before_optimizers_step")
def on_after_optimizers_step(self, trainer, optimizers):
print("on_after_optimizers_step")
def on_before_zero_grad(self, trainer, optimizers):
print("on_before_zero_grad")
def on_after_zero_grad(self, trainer, optimizers):
print("on_after_zero_grad")
def on_validate_begin(self, trainer):
print("on_validate_begin")
@@ -37,6 +37,7 @@ class TorchNormalModel_Classification_1(nn.Module):
x = torch.max(x, dim=-1)[1]
return {"preds": x, "target": y}
class TorchNormalModel_Classification_2(nn.Module):
"""
Implements only a forward function, to test the case where the user initializes DDP outside;
@@ -61,5 +62,31 @@ class TorchNormalModel_Classification_2(nn.Module):
return {"loss": loss, "preds": x, "target": y}
class TorchNormalModel_Classification_3(nn.Module):
"""
Implements only a forward function, to test the case where the user initializes DDP outside;
auto_param_call is disabled, and forward takes a single batch argument;
"""
def __init__(self, num_labels, feature_dimension):
super(TorchNormalModel_Classification_3, self).__init__()
self.num_labels = num_labels
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)
self.ac1 = nn.ReLU()
self.linear2 = nn.Linear(in_features=10, out_features=10)
self.ac2 = nn.ReLU()
self.output = nn.Linear(in_features=10, out_features=num_labels)
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, batch):
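# with auto_param_call disabled, the whole batch dict is passed in unchanged and unpacked here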
x = batch["x"] | |||||
y = batch["y"] | |||||
x = self.ac1(self.linear1(x)) | |||||
x = self.ac2(self.linear2(x)) | |||||
x = self.output(x) | |||||
loss = self.loss_fn(x, y) | |||||
x = torch.max(x, dim=-1)[1] | |||||
return {"loss": loss, "preds": x, "target": y} | |||||
@@ -2,34 +2,31 @@ import os
import sys
import __main__
from functools import wraps
import inspect
from inspect import ismethod
import functools
from copy import deepcopy
from io import StringIO
import time
import numpy as np
from fastNLP.core.utils.utils import get_class_that_defined_method
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK
from fastNLP.core.drivers.utils import distributed_open_proc
from fastNLP.core.log import logger
def get_class_that_defined_method(meth):
if isinstance(meth, functools.partial):
return get_class_that_defined_method(meth.func)
if inspect.ismethod(meth) or (inspect.isbuiltin(meth) and getattr(meth, '__self__', None) is not None and getattr(meth.__self__, '__class__', None)):
for cls in inspect.getmro(meth.__self__.__class__):
if meth.__name__ in cls.__dict__:
return cls
meth = getattr(meth, '__func__', meth)  # fallback to __qualname__ parsing
if inspect.isfunction(meth):
cls = getattr(inspect.getmodule(meth),
meth.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0],
None)
if isinstance(cls, type):
return cls
return getattr(meth, '__objclass__', None)  # handle special descriptor objects
def recover_logger(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
# save the logger's state
handlers = [handler for handler in logger.handlers]
level = logger.level
res = fn(*args, **kwargs)
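# restore the saved handlers and level once the wrapped test has finished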
logger.handlers = handlers
logger.setLevel(level)
return res
return wrapper
def magic_argv_env_context(fn):