@@ -16,6 +16,7 @@ __all__ = [ | |||||
"ResultsMonitor", | "ResultsMonitor", | ||||
'HasMonitorCallback', | 'HasMonitorCallback', | ||||
"FitlogCallback", | "FitlogCallback", | ||||
"TimerCallback", | |||||
# collators | # collators | ||||
'Collator', | 'Collator', | ||||
@@ -45,9 +46,11 @@ __all__ = [ | |||||
'TorchDataLoader', | 'TorchDataLoader', | ||||
'PaddleDataLoader', | 'PaddleDataLoader', | ||||
'JittorDataLoader', | 'JittorDataLoader', | ||||
'OneflowDataLoader', | |||||
'prepare_jittor_dataloader', | 'prepare_jittor_dataloader', | ||||
'prepare_paddle_dataloader', | 'prepare_paddle_dataloader', | ||||
'prepare_torch_dataloader', | 'prepare_torch_dataloader', | ||||
'prepare_oneflow_dataloader', | |||||
"prepare_dataloader", | "prepare_dataloader", | ||||
# dataset | # dataset | ||||
@@ -58,10 +61,13 @@ __all__ = [ | |||||
# drivers | # drivers | ||||
"TorchSingleDriver", | "TorchSingleDriver", | ||||
"TorchDDPDriver", | "TorchDDPDriver", | ||||
"DeepSpeedDriver", | |||||
"PaddleSingleDriver", | "PaddleSingleDriver", | ||||
"PaddleFleetDriver", | "PaddleFleetDriver", | ||||
"JittorSingleDriver", | "JittorSingleDriver", | ||||
"JittorMPIDriver", | "JittorMPIDriver", | ||||
"OneflowSingleDriver", | |||||
"OneflowDDPDriver", | |||||
# log | # log | ||||
"logger", | "logger", | ||||
@@ -21,7 +21,9 @@ __all__ = [ | |||||
"ResultsMonitor", | "ResultsMonitor", | ||||
'HasMonitorCallback', | 'HasMonitorCallback', | ||||
"FitlogCallback" | |||||
"FitlogCallback", | |||||
"TimerCallback" | |||||
] | ] | ||||
@@ -37,4 +39,4 @@ from .torch_callbacks import * | |||||
from .more_evaluate_callback import MoreEvaluateCallback | from .more_evaluate_callback import MoreEvaluateCallback | ||||
from .has_monitor_callback import ResultsMonitor, HasMonitorCallback | from .has_monitor_callback import ResultsMonitor, HasMonitorCallback | ||||
from .fitlog_callback import FitlogCallback | from .fitlog_callback import FitlogCallback | ||||
from .timer_callback import TimerCallback |
@@ -10,7 +10,7 @@ from .callback_event import Event, Filter | |||||
class Callback: | class Callback: | ||||
r""" | r""" | ||||
实际使用的 callback 类,不管是我们 fastNLP 默认提供的一些 callback 类,还是用户自己定制的 callback 类,都应该继承该基类; | |||||
实际使用的 callback 类,不管是 **fastNLP** 默认提供的一些 callback 实例,还是用户自己定制的 callback 类,都应该继承该基类; | |||||
callback 调用时机顺序大概如下:: | callback 调用时机顺序大概如下:: | ||||
Trainer.__init__(): | Trainer.__init__(): | ||||
@@ -41,17 +41,17 @@ class Callback: | |||||
finally: | finally: | ||||
on_train_end(trainer) | on_train_end(trainer) | ||||
其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/ | |||||
on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中特定 | |||||
的时间调用。 | |||||
其它 callback 例如 **on_evaluate_begin(trainer)** / **on_evaluate_end(trainer, results)** / **on_save_model(trainer)** / | |||||
**on_load_model(trainer)** / **on_save_checkpoint(trainer)** / **on_load_checkpoint(trainer)** 将根据需要在 :meth:`Trainer.run <fastNLP.core.controllers.Trainer.run>` | |||||
中特定的时间调用。 | |||||
""" | """ | ||||
def on_after_trainer_initialized(self, trainer, driver): | def on_after_trainer_initialized(self, trainer, driver): | ||||
r""" | r""" | ||||
在 `Trainer` 初始化后会被触发; | |||||
在 ``Trainer`` 初始化后会被触发; | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param driver: ``Trainer`` 中的 ``driver`` 实例; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
:param driver: :class:`~fastNLP.core.controllers.Trainer` 中的 ``driver`` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
@@ -59,7 +59,7 @@ class Callback: | |||||
r""" | r""" | ||||
在 '预跑'检测 开始前会被触发; | 在 '预跑'检测 开始前会被触发; | ||||
:param trainer: ``Trainer`` 实例; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
@@ -67,7 +67,7 @@ class Callback: | |||||
r""" | r""" | ||||
在 '预跑'检测 开始后会被触发; | 在 '预跑'检测 开始后会被触发; | ||||
:param trainer: ``Trainer`` 实例; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
:param sanity_check_res: 预跑得到的评测结果,关于对于 **预跑** 的解释,请见 :meth:`~fastNLP.core.controllers.trainer.Trainer.run`; | :param sanity_check_res: 预跑得到的评测结果,关于对于 **预跑** 的解释,请见 :meth:`~fastNLP.core.controllers.trainer.Trainer.run`; | ||||
""" | """ | ||||
pass | pass | ||||
@@ -76,7 +76,7 @@ class Callback: | |||||
r""" | r""" | ||||
在训练开始前会被触发; | 在训练开始前会被触发; | ||||
:param trainer: ``Trainer`` 实例; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
@@ -84,7 +84,7 @@ class Callback: | |||||
r""" | r""" | ||||
在训练完成后会被触发; | 在训练完成后会被触发; | ||||
:param trainer: ``Trainer`` 实例; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
@@ -92,7 +92,7 @@ class Callback: | |||||
r""" | r""" | ||||
在训练过程中的每一个 epoch 开始前会被触发; | 在训练过程中的每一个 epoch 开始前会被触发; | ||||
:param trainer: ``Trainer`` 实例; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
@@ -100,7 +100,7 @@ class Callback: | |||||
r""" | r""" | ||||
在训练过程中的每一个 epoch 完成后会被触发;此时 trainer.cur_epoch_idx 已经完成加 1 操作。 | 在训练过程中的每一个 epoch 完成后会被触发;此时 trainer.cur_epoch_idx 已经完成加 1 操作。 | ||||
:param trainer: ``Trainer`` 实例; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
@@ -108,7 +108,7 @@ class Callback: | |||||
r""" | r""" | ||||
在训练过程中准备取出下一个 batch 的数据时触发 | 在训练过程中准备取出下一个 batch 的数据时触发 | ||||
:param trainer: ``Trainer`` 实例; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
@@ -116,30 +116,30 @@ class Callback: | |||||
r""" | r""" | ||||
在训练过程中拿到当前的 batch 数据后会被触发; | 在训练过程中拿到当前的 batch 数据后会被触发; | ||||
:param trainer: ``Trainer`` 实例; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
def on_train_batch_begin(self, trainer, batch, indices): | def on_train_batch_begin(self, trainer, batch, indices): | ||||
r""" | r""" | ||||
在取得数据,执行完 ``input_mapping`` (如果 ``Trainer`` 传有该参数),并且移动 ``batch`` 中的 ``tensor`` 到了指定设备。 | |||||
在取得数据,执行完 ``input_mapping`` (如果 :class:`~fastNLP.core.controllers.Trainer` 传有该参数),并且移动 ``batch`` 中的张量到了指定设备之后会被触发。 | |||||
其中 ``batch`` 中的数据格式要么是 ``Dataloader`` 返回的每个 ``batch`` 的格式;要么是 ``input_mapping`` 之后的内容。 | 其中 ``batch`` 中的数据格式要么是 ``Dataloader`` 返回的每个 ``batch`` 的格式;要么是 ``input_mapping`` 之后的内容。 | ||||
如果 ``batch`` 是 ``dict`` 类型,直接增删其中的 ``key`` 或 修改其中的 ``value`` 会影响到输入到 ``model`` 的中的 ``batch`` 数据。 | |||||
如果 ``batch`` 是 ``dict`` 类型,直接增删其中的 key 或 修改其中的 value 会影响到输入模型的中的 ``batch`` 数据。 | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
:param batch: batch 的数据,已经经过 ``input_mapping`` (如果有) 以及移动到指定设备 。 | :param batch: batch 的数据,已经经过 ``input_mapping`` (如果有) 以及移动到指定设备 。 | ||||
:param list[int] indices: 当前的 ``batch`` 是 ``dataset`` 中的哪些数据。仅在 ``DataLoader`` 支持得到当前 ``batch index`` 的时候有值, | |||||
其它时候为 None 。 | |||||
:param list[int] indices: 当前的 ``batch`` 是数据集中的哪些数据。仅在 ``DataLoader`` 支持得到当前 ``batch index`` 的时候有值, | |||||
其它时候为 ``None`` 。 | |||||
""" | """ | ||||
pass | pass | ||||
def on_train_batch_end(self, trainer): | def on_train_batch_end(self, trainer): | ||||
r""" | r""" | ||||
完成一个 batch 的训练(forward)、梯度回传(backward)、梯度更新(step)、梯度置零、batch_idx_in_epoch与 | |||||
global_forward_batches累计加1操作。其中梯度更新】梯度置零操作会考虑 accumulation_steps ,所以不一定在当前 batch 会 | |||||
完成一个 batch 的训练(forward)、梯度回传(backward)、梯度更新(step)、梯度置零、batch_idx_in_epoch 与 | |||||
global_forward_batches 累计加1操作之后会被触发。其中梯度更新、梯度置零操作会考虑 **accumulation_steps** ,所以不一定在当前 batch 会 | |||||
执行。 | 执行。 | ||||
:param trainer: ``Trainer`` 实例; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
@@ -147,41 +147,42 @@ class Callback: | |||||
r""" | r""" | ||||
在训练过程遇到异常时调用。 | 在训练过程遇到异常时调用。 | ||||
:param trainer: ``Trainer`` 实例; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
:param exception: 遭遇的异常; | :param exception: 遭遇的异常; | ||||
""" | """ | ||||
pass | pass | ||||
def on_save_model(self, trainer): | def on_save_model(self, trainer): | ||||
r""" | r""" | ||||
当调用 Trainer.save_model() 时调用,此刻模型还未保存。 | |||||
当调用 :meth:`Trainer.save_model() <fastNLP.core.controllers.Trainer.save_model>` 时调用,此刻模型还未保存。 | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
def on_load_model(self, trainer): | def on_load_model(self, trainer): | ||||
r""" | r""" | ||||
当调用 Trainer.load_model() 加载模型时调用,此刻模型还未加载。 | |||||
当调用 :meth:`Trainer.load_model() <fastNLP.core.controllers.Trainer.load_model>` 加载模型时调用,此刻模型还未加载。 | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
def on_save_checkpoint(self, trainer) -> Dict: | def on_save_checkpoint(self, trainer) -> Dict: | ||||
r""" | r""" | ||||
当 Trainer 将要保存 checkpoint 的时候触发 (即调用 Trainer.save_checkpoint() 函数时),该函数用于保存当前 callback 在恢复需要的相关数据。 | |||||
当 Trainer 将要保存 checkpoint 的时候触发 (即调用 :meth:`Trainer.save_checkpoint() <fastNLP.core.controllers.Trainer.save_checkpoint>` | |||||
函数时),该函数用于保存当前 callback 在恢复时需要的相关数据。 | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
def on_load_checkpoint(self, trainer, states: Optional[Dict]): | def on_load_checkpoint(self, trainer, states: Optional[Dict]): | ||||
r""" | r""" | ||||
当 Trainer 要恢复 checkpoint 的时候触发(即调用 Trainer.load_checkpoint() 函数时, 此刻 Trainer 与 Driver 已经加载好自身 | |||||
的状态), 参数 states 为 Callback 在调用 on_save_checkpoint() 的返回值。 | |||||
当 Trainer 要恢复 checkpoint 的时候触发(即调用 :meth:`Trainer.load_checkpoint() <fastNLP.core.controllers.Trainer.load_checkpoint>` | |||||
函数时, 此刻 Trainer 与 Driver 已经加载好自身的状态), 参数 states 为 Callback 在调用 :meth:`on_save_checkpoint` 的返回值。 | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
:param states: | :param states: | ||||
""" | """ | ||||
pass | pass | ||||
@@ -190,7 +191,7 @@ class Callback: | |||||
r""" | r""" | ||||
在 backward 前执行。 | 在 backward 前执行。 | ||||
:param trainer: ``Trainer`` 实例; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
:param outputs: ``model`` 的返回内容。如果有 ``output_mapping``,则 ``outputs`` 中的内容为已经执行了 ``output_mapping`` 后的结果。 | :param outputs: ``model`` 的返回内容。如果有 ``output_mapping``,则 ``outputs`` 中的内容为已经执行了 ``output_mapping`` 后的结果。 | ||||
""" | """ | ||||
pass | pass | ||||
@@ -198,54 +199,54 @@ class Callback: | |||||
def on_after_backward(self, trainer): | def on_after_backward(self, trainer): | ||||
r""" | r""" | ||||
在 ``backward`` 后执行。在多卡场景下,由于 ``accumulation_steps`` 的影响,仅在需要真正 ``update`` 参数那次梯度回传才会触发梯度同步, | 在 ``backward`` 后执行。在多卡场景下,由于 ``accumulation_steps`` 的影响,仅在需要真正 ``update`` 参数那次梯度回传才会触发梯度同步, | ||||
因此在多卡且使用 ``accumulation_steps`` 时,可能存在某些 ``step`` 各卡上梯度不一致的问题。 | |||||
因此在多卡且使用 ``accumulation_steps`` 时,可能存在某些 step 各卡上梯度不一致的问题。 | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
def on_before_optimizers_step(self, trainer, optimizers): | def on_before_optimizers_step(self, trainer, optimizers): | ||||
r""" | r""" | ||||
在进行 optimizer 优化进行前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | |||||
在进行 optimizer 优化进行前调用。该接口不一定每次前向计算都会触发,实际调用会受到 ``accumulation_steps`` 的影响。 | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param optimizers: 优化器,内容为在 ``Trainer`` 初始化时传入的值。 | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
:param optimizers: 优化器,内容为在 :class:`~fastNLP.core.controllers.Trainer` 初始化时传入的值。 | |||||
""" | """ | ||||
pass | pass | ||||
def on_after_optimizers_step(self, trainer, optimizers): | def on_after_optimizers_step(self, trainer, optimizers): | ||||
r""" | r""" | ||||
在进行 optimizer 优化进行后调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | |||||
在进行 optimizer 优化进行后调用。该接口不一定每次前向计算都会触发,实际调用会受到 ``accumulation_steps`` 的影响。 | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param optimizers: 优化器,内容为在 ``Trainer`` 初始化时传入的值。 | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
:param optimizers: 优化器,内容为在 :class:`~fastNLP.core.controllers.Trainer` 初始化时传入的值。 | |||||
""" | """ | ||||
pass | pass | ||||
def on_before_zero_grad(self, trainer, optimizers): | def on_before_zero_grad(self, trainer, optimizers): | ||||
r""" | r""" | ||||
在进行模型梯度置零前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | |||||
在进行模型梯度置零前调用。该接口不一定每次前向计算都会触发,实际调用会受到 ``accumulation_steps`` 的影响。 | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param optimizers: 优化器,内容为在 ``Trainer`` 初始化时传入的值。 | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
:param optimizers: 优化器,内容为在 :class:`~fastNLP.core.controllers.Trainer` 初始化时传入的值。 | |||||
""" | """ | ||||
pass | pass | ||||
def on_after_zero_grad(self, trainer, optimizers): | def on_after_zero_grad(self, trainer, optimizers): | ||||
r""" | r""" | ||||
在进行模型梯度置零后调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | |||||
在进行模型梯度置零后调用。该接口不一定每次前向计算都会触发,实际调用会受到 ``accumulation_steps`` 的影响。 | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param optimizers: 优化器,内容为在 ``Trainer`` 初始化时传入的值。 | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
:param optimizers: 优化器,内容为在 :class:`~fastNLP.core.controllers.Trainer` 初始化时传入的值。 | |||||
""" | """ | ||||
pass | pass | ||||
def on_evaluate_begin(self, trainer): | def on_evaluate_begin(self, trainer): | ||||
r""" | r""" | ||||
在将要进行 evaluate 时调用。如果是设置的以 step 数量 或 自定义地 决定 evaluate 的频率,该接口是在 on_train_batch_end 之后 | |||||
进行调用。如果是以 epoch 数量决定调用,该接口是在 on_train_epoch_end 之后调用。 | |||||
在将要进行 ``evaluate`` 时调用。如果是设置的以 step 数量或自定义地决定 evaluate 的频率,该接口是在 :meth:`on_train_batch_end` 之后 | |||||
进行调用。如果是以 epoch 数量决定调用时机,该接口是在 :meth:`on_train_epoch_end` 之后调用。 | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
@@ -253,17 +254,17 @@ class Callback: | |||||
r""" | r""" | ||||
结束 evaluate 时调用,并把 evaluate 的结果传入。 | 结束 evaluate 时调用,并把 evaluate 的结果传入。 | ||||
:param trainer: ``Trainer`` 实例; | |||||
:param results: ``Trainer`` 内置的 ``Evaluator`` 评测的结果,通常是个 ``dict``; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
:param results: :class:`~fastNLP.core.controllers.Trainer` 内置的 ``Evaluator`` 评测的结果,通常是个 ``dict``; | |||||
""" | """ | ||||
pass | pass | ||||
@property | @property | ||||
def callback_name(self): | def callback_name(self): | ||||
r""" | r""" | ||||
``callback`` 的名称,我们会使用该名称从 ``checkpoint`` 中读取的相应的 ``state`` 并传递给 ``on_load_checkpoint()`` 函数。 | |||||
``callback`` 的名称,我们会使用该名称从 ``checkpoint`` 中读取的相应的 ``state`` 并传递给 :meth:`on_load_checkpoint` 函数。 | |||||
:return: 返回用于区分该 ``callback`` 实例的 ``name``; | |||||
:return: 返回用于区分该 ``callback`` 实例的名称; | |||||
""" | """ | ||||
return self.__class__.__name__ | return self.__class__.__name__ | ||||
@@ -31,13 +31,13 @@ def check_legality(fn): | |||||
class Event: | class Event: | ||||
""" | """ | ||||
与 Trainer.on 函数配合使用,达到控制 callback 函数运行时机的目的。 | |||||
与 :meth:`Trainer.on` 函数配合使用,达到控制 callback 函数运行时机的目的。 | |||||
:param value: Trainer 的 callback 时机。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变量分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param value: Trainer 的 callback 时机; | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
""" | """ | ||||
every: Optional[int] | every: Optional[int] | ||||
once: Optional[int] | once: Optional[int] | ||||
@@ -53,416 +53,416 @@ class Event: | |||||
return "<event={0}, every={1}, once={2}, filter fn is:{3}>".format(self.value, self.every, self.once, | return "<event={0}, every={1}, once={2}, filter fn is:{3}>".format(self.value, self.every, self.once, | ||||
self.filter_fn) | self.filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_after_trainer_initialized(every=None, once=None, filter_fn=None): | def on_after_trainer_initialized(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_after_trainer_initialized 时 | |||||
当 Trainer 运行到 :func:`on_after_trainer_initialized` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。默认为 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_after_trainer_initialized', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_after_trainer_initialized', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_sanity_check_begin(every=None, once=None, filter_fn=None): | def on_sanity_check_begin(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_sanity_check_begin 时 | |||||
当 Trainer 运行到 :func:`on_sanity_check_begin` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_sanity_check_begin', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_sanity_check_begin', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_sanity_check_end(every=None, once=None, filter_fn=None): | def on_sanity_check_end(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_sanity_check_end 时 | |||||
当 Trainer 运行到 :func:`on_sanity_check_end` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_sanity_check_end', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_sanity_check_end', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_train_begin(every=None, once=None, filter_fn=None): | def on_train_begin(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_train_begin 时 | |||||
当 Trainer 运行到 :func:`on_train_begin` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_train_begin', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_train_begin', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_train_end(every=None, once=None, filter_fn=None): | def on_train_end(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_train_end 时 | |||||
当 Trainer 运行到 :func:`on_train_end` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_train_end', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_train_end', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_train_epoch_begin(every=None, once=None, filter_fn=None): | def on_train_epoch_begin(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_train_epoch_begin 时 | |||||
当 Trainer 运行到 :func:`on_train_epoch_begin` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_train_epoch_begin', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_train_epoch_begin', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_train_epoch_end(every=None, once=None, filter_fn=None): | def on_train_epoch_end(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_train_epoch_end 时 | |||||
当 Trainer 运行到 :func:`on_train_epoch_end` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_train_epoch_end', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_train_epoch_end', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_fetch_data_begin(every=None, once=None, filter_fn=None): | def on_fetch_data_begin(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_fetch_data_begin 时 | |||||
当 Trainer 运行到 :func:`on_fetch_data_begin` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_fetch_data_begin', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_fetch_data_begin', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_fetch_data_end(every=None, once=None, filter_fn=None): | def on_fetch_data_end(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_fetch_data_end 时 | |||||
当 Trainer 运行到 :func:`on_fetch_data_end` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_fetch_data_end', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_fetch_data_end', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_train_batch_begin(every=None, once=None, filter_fn=None): | def on_train_batch_begin(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_train_batch_begin 时 | |||||
当 Trainer 运行到 :func:`on_train_batch_begin` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_train_batch_begin', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_train_batch_begin', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_train_batch_end(every=None, once=None, filter_fn=None): | def on_train_batch_end(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_train_batch_end 时 | |||||
当 Trainer 运行到 :func:`on_train_batch_end` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_train_batch_end', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_train_batch_end', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_exception(every=None, once=None, filter_fn=None): | def on_exception(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_exception 时 | |||||
当 Trainer 运行到 :func:`on_exception` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_exception', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_exception', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_save_model(every=None, once=None, filter_fn=None): | def on_save_model(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_save_model 时 | |||||
当 Trainer 运行到 :func:`on_save_model` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_save_model', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_save_model', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_load_model(every=None, once=None, filter_fn=None): | def on_load_model(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_load_model 时 | |||||
当 Trainer 运行到 :func:`on_load_model` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_load_model', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_load_model', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_save_checkpoint(every=None, once=None, filter_fn=None): | def on_save_checkpoint(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_save_checkpoint 时 | |||||
当 Trainer 运行到 :func:`on_save_checkpoint` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_save_checkpoint', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_save_checkpoint', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_load_checkpoint(every=None, once=None, filter_fn=None): | def on_load_checkpoint(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_load_checkpoint 时 | |||||
当 Trainer 运行到 :func:`on_load_checkpoint` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_load_checkpoint(every=None, once=None, filter_fn=None): | def on_load_checkpoint(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_load_checkpoint 时 | |||||
当 Trainer 运行到 :func:`on_load_checkpoint` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_before_backward(every=None, once=None, filter_fn=None): | def on_before_backward(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_before_backward 时 | |||||
当 Trainer 运行到 :func:`on_before_backward` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_before_backward', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_before_backward', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_after_backward(every=None, once=None, filter_fn=None): | def on_after_backward(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_after_backward 时 | |||||
当 Trainer 运行到 :func:`on_after_backward` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_after_backward', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_after_backward', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_before_optimizers_step(every=None, once=None, filter_fn=None): | def on_before_optimizers_step(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_before_optimizers_step 时 | |||||
当 Trainer 运行到 :func:`on_before_optimizers_step` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_before_optimizers_step', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_before_optimizers_step', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_after_optimizers_step(every=None, once=None, filter_fn=None): | def on_after_optimizers_step(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_after_optimizers_step 时 | |||||
当 Trainer 运行到 :func:`on_after_optimizers_step` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_after_optimizers_step', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_after_optimizers_step', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_before_zero_grad(every=None, once=None, filter_fn=None): | def on_before_zero_grad(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_before_zero_grad 时 | |||||
当 Trainer 运行到 :func:`on_before_zero_grad` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_before_zero_grad', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_before_zero_grad', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_after_zero_grad(every=None, once=None, filter_fn=None): | def on_after_zero_grad(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_after_zero_grad 时 | |||||
当 Trainer 运行到 :func:`on_after_zero_grad` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_after_zero_grad', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_after_zero_grad', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_evaluate_begin(every=None, once=None, filter_fn=None): | def on_evaluate_begin(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_evaluate_begin 时 | |||||
当 Trainer 运行到 :func:`on_evaluate_begin` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_evaluate_begin', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_evaluate_begin', every=every, once=once, filter_fn=filter_fn) | ||||
@staticmethod | @staticmethod | ||||
@check_legality | |||||
def on_evaluate_end(every=None, once=None, filter_fn=None): | def on_evaluate_end(every=None, once=None, filter_fn=None): | ||||
""" | """ | ||||
当 Trainer 运行到 on_evaluate_end 时 | |||||
当 Trainer 运行到 :func:`on_evaluate_end` 时触发; | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:param every: 每触发多少次才真正运行一次; | |||||
:param once: 在第一次运行后时候再次执行; | |||||
:param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 | |||||
`filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; | |||||
:return: | :return: | ||||
""" | """ | ||||
return Event(value='on_evaluate_end', every=every, once=once, filter_fn=filter_fn) | return Event(value='on_evaluate_end', every=every, once=once, filter_fn=filter_fn) | ||||
class Filter: | class Filter: | ||||
def __init__(self, every: Optional[int] = None, once: Optional[bool] = None, filter_fn: Optional[Callable] = None): | |||||
r""" | |||||
通过该 `Filter` 作为函数修饰器来控制一个函数的实际的运行频率。 | |||||
r""" | |||||
可以控制一个函数实际的运行频率的函数修饰器。 | |||||
:param every: 表示一个函数隔多少次运行一次; | |||||
:param once: 表示一个函数只运行一次; | |||||
:param filter_fn: 用户定制的频率控制函数;注意该函数内部的频率判断应当是无状态的,除了参数 `self.num_called` 和 | |||||
`self.num_executed` 外,因为我们会在预跑后重置这两个参数的状态; | |||||
""" | |||||
:param every: 表示一个函数隔多少次运行一次; | |||||
:param once: 表示一个函数是否只运行一次; | |||||
:param filter_fn: 用户定制的频率控制函数;注意该函数内部的频率判断应当是无状态的,除了参数 `self.num_called` 和 | |||||
`self.num_executed` 外,因为我们会在预跑后重置这两个参数的状态; | |||||
""" | |||||
def __init__(self, every: Optional[int] = None, once: Optional[bool] = None, filter_fn: Optional[Callable] = None): | |||||
# check legality | # check legality | ||||
check_legality(lambda *args,**kwargs:...)(every, once, filter_fn) | check_legality(lambda *args,**kwargs:...)(every, once, filter_fn) | ||||
if (every is None) and (once is None) and (filter_fn is None): | if (every is None) and (once is None) and (filter_fn is None): | ||||
@@ -75,12 +75,13 @@ class CallbackManager: | |||||
def __init__(self, callbacks: Optional[List[Callback]]): | def __init__(self, callbacks: Optional[List[Callback]]): | ||||
r""" | r""" | ||||
注意 callback 的调用顺序: | |||||
注意 callback 的调用顺序为: | |||||
1. 通过函数修饰器 `Trainer.on` 添加的 callback 函数; | 1. 通过函数修饰器 `Trainer.on` 添加的 callback 函数; | ||||
2. 通过 `Trainer` 的参数 `callbacks` 添加的 callback 类; | 2. 通过 `Trainer` 的参数 `callbacks` 添加的 callback 类; | ||||
3. 通过 `Trainer.add_callback_fn` 添加的 callback 函数; | 3. 通过 `Trainer.add_callback_fn` 添加的 callback 函数; | ||||
:param callbacks: 初始化时可以传入的一系列 callback 类,通常为用户在初始化 ``Trainer`` 时直接传入的 callback 类; | |||||
:param callbacks: 初始化时可以传入的一系列 :class:`~fastNLP.Callback` 类,通常为用户在初始化 ``Trainer`` 时直接传入的 callback 列表; | |||||
""" | """ | ||||
self._need_reproducible_sampler = False | self._need_reproducible_sampler = False | ||||
@@ -106,12 +107,9 @@ class CallbackManager: | |||||
def initialize_class_callbacks(self): | def initialize_class_callbacks(self): | ||||
r""" | r""" | ||||
在实际的运行过程中,我们是将具体的一个 callback 实例拆分为单独的一个个 callback 函数,然后将它们加在一个字典里,该字典的键值就是 | |||||
在实际的运行过程中,我们会将具体的一个 callback 实例拆分为单独的一个个 callback 函数,然后将它们加在一个字典里,该字典的键值就是 | |||||
一个个 callback 时机,也就是 `Event` 的类别; | 一个个 callback 时机,也就是 `Event` 的类别; | ||||
如果一个 callback 类的 callback 函数并不具备任何作用,我们实际并不会将其加在字典当中; | 如果一个 callback 类的 callback 函数并不具备任何作用,我们实际并不会将其加在字典当中; | ||||
:param callbacks: | |||||
:return: | |||||
""" | """ | ||||
for each_callback in self.class_callbacks: | for each_callback in self.class_callbacks: | ||||
self._need_reproducible_sampler |= each_callback.need_reproducible_sampler | self._need_reproducible_sampler |= each_callback.need_reproducible_sampler | ||||
@@ -144,11 +142,12 @@ class CallbackManager: | |||||
用于断点重训的 callback 的保存函数; | 用于断点重训的 callback 的保存函数; | ||||
该函数主要涉及两个方面: | 该函数主要涉及两个方面: | ||||
1. callback 的状态的保存;我们会调用每一个 callback 的 `on_save_checkpoint` 方法,该方法应当返回一个字典,其中包含着 | |||||
断点重训应当保存的状态; | |||||
1. callback 的状态的保存;我们会调用每一个 callback 的 :func:`on_save_checkpoint` 方法,该方法应当返回一个字典,其中包含着 | |||||
断点重训应当保存的状态; | |||||
2. 每一个具体的 callback 函数的 filter 的状态; | 2. 每一个具体的 callback 函数的 filter 的状态; | ||||
:return: 一个包含上述内容的字典: | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
:return: 一个包含上述内容的字典,格式如下: | |||||
.. code-block:: | .. code-block:: | ||||
{ | { | ||||
@@ -195,11 +194,10 @@ class CallbackManager: | |||||
def on_load_checkpoint(self, trainer, states: Dict): | def on_load_checkpoint(self, trainer, states: Dict): | ||||
r""" | r""" | ||||
用于断点重训的加载函数; | |||||
对应于断点重训的保存函数; | |||||
用于断点重训的加载函数,对应于断点重训的保存函数; | |||||
:param trainer: `Trainer` | |||||
:param states: 见 `on_save_checkpoint` 函数的返回值; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
:param states: 同 :func:`on_save_checkpoint` 函数的返回值; | |||||
""" | """ | ||||
# 1. 先恢复每一个具体的 callback 函数的 filter 的状态; | # 1. 先恢复每一个具体的 callback 函数的 filter 的状态; | ||||
@@ -24,36 +24,37 @@ class CheckpointCallback(Callback): | |||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 | - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 | ||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 | - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 | ||||
model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。若 model_save_fn 不为 None, | |||||
``model_save_fn`` 为 ``None`` ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。若 ``model_save_fn`` 不为 ``None``, | |||||
则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。默认情况下,本 checkpoint 只保存了 model | 则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。默认情况下,本 checkpoint 只保存了 model | ||||
的状态;如还需保存 Trainer 的状态以断点重训的话,请使用 ``save_object='trainer'`` 。 | 的状态;如还需保存 Trainer 的状态以断点重训的话,请使用 ``save_object='trainer'`` 。 | ||||
:param monitor: 监控的 metric 值。 | :param monitor: 监控的 metric 值。 | ||||
* 为 ``None`` | * 为 ``None`` | ||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | * 为 ``str`` | ||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 :class:`Callable` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | :param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | ||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | ||||
:param every_n_epochs: 多少个 epoch 保存一次。 | :param every_n_epochs: 多少个 epoch 保存一次。 | ||||
:param every_n_batches: 多少个 batch 保存一次。 | :param every_n_batches: 多少个 batch 保存一次。 | ||||
:param last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。如果为 False 则不会保存 {save_object}-last 文件 | |||||
:param topk: 保存 monitor 结果 topK 个。 | |||||
:param on_exceptions: 在出异常信息时,是否保存。传入需要捕获的异常的类。默认将捕获 EarlyStopException 。 | |||||
:param last: 如果为 ``True`` ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。如果为 ``False`` 则不会保存 ``{save_object}-last`` 文件 | |||||
:param topk: 保存 monitor 结果中的 ``topk`` 个。 | |||||
:param on_exceptions: 在出异常信息时,是否保存。传入需要捕获的异常的类。默认将捕获 :class:`~fastNLP.core.callbacks.EarlyStopException` 。 | |||||
:param larger_better: monitor 的值是否时越大越好。 | :param larger_better: monitor 的值是否时越大越好。 | ||||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | |||||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 ``model_save_fn`` 不为 ``None`` 时,该参数无效。 | |||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | ||||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||||
保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||||
如果传入了 ``model_save_fn`` 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
:param save_object: 可选 ``['trainer', 'model']`` ,表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||||
保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.core.controllers.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||||
点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | 点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | ||||
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 | |||||
fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。 | |||||
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 ``True`` ,在保存 topk 模型的 folder 中还将额外保存一个 | |||||
``fastnlp_evaluate_results.json`` 文件,记录当前的 results。仅在设置了 ``topk`` 的场景下有用,默认为 ``True`` 。 | |||||
:param kwargs: | :param kwargs: | ||||
""" | """ | ||||
def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None, | def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None, | ||||
@@ -10,16 +10,16 @@ from fastNLP.core.utils.exceptions import EarlyStopException | |||||
class EarlyStopCallback(HasMonitorCallback): | class EarlyStopCallback(HasMonitorCallback): | ||||
""" | """ | ||||
用于 early stop 的 callback 。当监控的结果连续多少次没有变好边 raise 一个 EarlyStopException 。 | |||||
用于 early stop 的 callback 。当监控的结果连续多少次没有变好便 raise 一个 :class:`EarlyStopException` 。 | |||||
:param monitor: 监控的 metric 值。 | :param monitor: 监控的 metric 值。 | ||||
* 为 ``None`` | * 为 ``None`` | ||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | * 为 ``str`` | ||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | ||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | ||||
* 为 ``Callable`` | |||||
* 为 :class:`Callable` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | ||||
的 ``monitor`` 值请返回 ``None`` 。 | 的 ``monitor`` 值请返回 ``None`` 。 | ||||
:param larger_better: monitor 的值是否是越大越好。 | :param larger_better: monitor 的值是否是越大越好。 | ||||
@@ -14,20 +14,21 @@ if _module_available('fitlog'): | |||||
class FitlogCallback(HasMonitorCallback): | class FitlogCallback(HasMonitorCallback): | ||||
""" | """ | ||||
自动记录 ``evaluation`` 结果到 ``fitlog`` 中。会自动记录每一次 ``evaluate`` 后的结果;同时会根据 | 自动记录 ``evaluation`` 结果到 ``fitlog`` 中。会自动记录每一次 ``evaluate`` 后的结果;同时会根据 | ||||
``monitor`` 记录最好的结果。另外,会自动将非 ``rank 0`` 上的 ``fitlog`` 设置为 ``debug`` 状态。同时还会在 ``fitlog`` 的 | |||||
``other`` 列中记录一个 ``launch_time`` ,可以通过这个数值找到当前这个脚本的在 save_folder (如果有使用其它需要保存模型的 | |||||
``Callback`` ,例如 :class:`~fastNLP.CheckpointCallback` )下的文件夹名称。 | |||||
``monitor`` 记录最好的结果。另外,会自动将非 ``rank 0`` 上的 ``fitlog`` 设置为 ``debug`` 状态。同时还会在 ``fitlog`` 的 | |||||
``other`` 列中记录一个 ``launch_time`` ,可以通过这个数值找到当前这个脚本的在 save_folder (如果有使用其它需要保存模型的 | |||||
``Callback`` ,例如 :class:`~fastNLP.core.callbacks.CheckpointCallback` )下的文件夹名称。 | |||||
:param monitor: 监控的 metric 值。 | :param monitor: 监控的 metric 值。 | ||||
* 为 ``None`` | * 为 ``None`` | ||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | * 为 ``str`` | ||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 :class:`Callable` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: 是否是越大越好。 | :param larger_better: 是否是越大越好。 | ||||
:param log_exception: 是否记录 ``exception`` 。 | :param log_exception: 是否记录 ``exception`` 。 | ||||
:param log_loss_every: 多少个 ``batch`` 记录一次 loss 到 ``fitlog`` 中。 | :param log_loss_every: 多少个 ``batch`` 记录一次 loss 到 ``fitlog`` 中。 | ||||
@@ -44,7 +45,7 @@ class FitlogCallback(HasMonitorCallback): | |||||
if get_global_rank() != 0: # 如果不是 global rank 为 0 ,需要关闭 fitlog | if get_global_rank() != 0: # 如果不是 global rank 为 0 ,需要关闭 fitlog | ||||
fitlog.debug() | fitlog.debug() | ||||
super().on_after_trainer_initialized(trainer, driver) | super().on_after_trainer_initialized(trainer, driver) | ||||
fitlog.add_other('launch_time', os.environ['FASTNLP_LAUNCH_TIME']) | |||||
fitlog.add_other(name='launch_time', value=os.environ['FASTNLP_LAUNCH_TIME']) | |||||
def on_sanity_check_end(self, trainer, sanity_check_res): | def on_sanity_check_end(self, trainer, sanity_check_res): | ||||
super(FitlogCallback, self).on_sanity_check_end(trainer, sanity_check_res) | super(FitlogCallback, self).on_sanity_check_end(trainer, sanity_check_res) | ||||
@@ -26,19 +26,19 @@ class CanItemDataType(ABC): | |||||
class ResultsMonitor: | class ResultsMonitor: | ||||
""" | """ | ||||
可用于监控某个数值,并通过 is_better_results() 等接口实现检测结果是否变得更好了。 | |||||
可用于监控某个数值,并通过 :meth:`is_better_results` 等接口检测结果是否变得更好。 | |||||
:param monitor: 监控的 metric 值。 | |||||
:param monitor: 监控的 metric 值: | |||||
* 为 ``None`` | * 为 ``None`` | ||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置); | |||||
* 为 ``str`` | * 为 ``str`` | ||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | ||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` ; | |||||
* 为 :class:`Callable` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | ||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: monitor 是否时越大越好 | |||||
的 ``monitor`` 值请返回 ``None`` ; | |||||
:param larger_better: monitor 是否为越大越好; | |||||
""" | """ | ||||
def __init__(self, monitor:Union[Callback, str], larger_better:bool=True): | def __init__(self, monitor:Union[Callback, str], larger_better:bool=True): | ||||
self.set_monitor(monitor, larger_better) | self.set_monitor(monitor, larger_better) | ||||
@@ -60,7 +60,7 @@ class ResultsMonitor: | |||||
def itemize_results(self, results): | def itemize_results(self, results): | ||||
""" | """ | ||||
将结果中有 .item() 方法的都调用一下,使得 tensor 类型的数据转为 python 内置类型。 | |||||
执行结果中所有对象的 :meth:`item` 方法(如果没有则忽略),使得 Tensor 类型的数据转为 python 内置类型。 | |||||
:param results: | :param results: | ||||
:return: | :return: | ||||
@@ -69,10 +69,10 @@ class ResultsMonitor: | |||||
def get_monitor_value(self, results:Dict)->Union[float, None]: | def get_monitor_value(self, results:Dict)->Union[float, None]: | ||||
""" | """ | ||||
获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用 最长公共字符串算法 匹配的方式寻找。 | |||||
获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用 **最长公共字符串算法** 匹配的方式寻找。 | |||||
:param results: 评测结果。 | |||||
:return: 如果为 None ,表明此次没有找到合适的monitor | |||||
:param results: 评测结果; | |||||
:return: monitor 的值;如果为 ``None`` ,表明此次没有找到合适的monitor; | |||||
""" | """ | ||||
if len(results) == 0 or self.monitor is None: | if len(results) == 0 or self.monitor is None: | ||||
return None | return None | ||||
@@ -100,10 +100,10 @@ class ResultsMonitor: | |||||
def is_better_monitor_value(self, monitor_value: float, keep_if_better=True): | def is_better_monitor_value(self, monitor_value: float, keep_if_better=True): | ||||
""" | """ | ||||
检测 monitor_value 是否是更好的 | |||||
检测 ``monitor_value`` 是否是更好的 | |||||
:param monitor_value: 待检查的 monitor_value 。如果为 None ,返回 False | |||||
:param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。 | |||||
:param monitor_value: 待检查的 ``monitor_value`` 。如果为 ``None`` ,返回 False; | |||||
:param keep_if_better: 如果传入的 ``monitor_value`` 值更好,则将其保存下来; | |||||
:return: | :return: | ||||
""" | """ | ||||
if monitor_value is None: | if monitor_value is None: | ||||
@@ -115,10 +115,10 @@ class ResultsMonitor: | |||||
def is_better_results(self, results, keep_if_better=True): | def is_better_results(self, results, keep_if_better=True): | ||||
""" | """ | ||||
检测给定的 results 是否比上一次更好,如果本次 results 中没有找到相关的monitor 返回 False。 | |||||
检测给定的 ``results`` 是否比上一次更好,如果本次 results 中没有找到相关的 monitor 返回 ``False``。 | |||||
:param results: evaluation 结果。 | |||||
:param keep_if_better: 当返回为 True 时,是否保存到 self.monitor_value 中。 | |||||
:param results: evaluation 结果; | |||||
:param keep_if_better: 当返回为 ``True`` 时,是否保存到 ``self.monitor_value`` 中; | |||||
:return: | :return: | ||||
""" | """ | ||||
monitor_value = self.get_monitor_value(results) | monitor_value = self.get_monitor_value(results) | ||||
@@ -128,7 +128,7 @@ class ResultsMonitor: | |||||
def is_former_monitor_value_better(self, monitor_value1, monitor_value2): | def is_former_monitor_value_better(self, monitor_value1, monitor_value2): | ||||
""" | """ | ||||
传入的两个值中,是否monitor_value1的结果更好。 | |||||
传入的两个值中,是否 ``monitor_value1`` 的结果更好。 | |||||
:param monitor_value1: | :param monitor_value1: | ||||
:param monitor_value2: | :param monitor_value2: | ||||
@@ -149,7 +149,7 @@ class ResultsMonitor: | |||||
@property | @property | ||||
def monitor_name(self): | def monitor_name(self): | ||||
""" | """ | ||||
返回 monitor 的名字,如果 monitor 是个 callable 的函数,则返回该函数的名称。 | |||||
返回 monitor 的名字,如果 monitor 是个 Callable 的函数,则返回该函数的名称。 | |||||
:return: | :return: | ||||
""" | """ | ||||
@@ -171,7 +171,7 @@ class ResultsMonitor: | |||||
@property | @property | ||||
def log_name(self) -> str: | def log_name(self) -> str: | ||||
""" | """ | ||||
内部用于打印信息使用 | |||||
内部用于打印当前类别信息使用 | |||||
:return: | :return: | ||||
""" | """ | ||||
@@ -185,20 +185,20 @@ class ResultsMonitor: | |||||
class HasMonitorCallback(ResultsMonitor, Callback): | class HasMonitorCallback(ResultsMonitor, Callback): | ||||
""" | """ | ||||
该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了 | 该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了 | ||||
(1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。 | |||||
(1)判断 monitor 合法性;(2)在需要时, 根据 trainer 的 monitor 设置自己的 monitor 名称。 | |||||
:param monitor: 监控的 metric 值。 | |||||
:param monitor: 监控的 metric 值: | |||||
* 为 ``None`` | * 为 ``None`` | ||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置); | |||||
* 为 ``str`` | * 为 ``str`` | ||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | ||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` ; | |||||
* 为 :class:`Callable` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | ||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: monitor 是否时越大越好 | |||||
:param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 True ,且没检测到设置 monitor 会报错。 | |||||
的 ``monitor`` 值请返回 ``None`` ; | |||||
:param larger_better: monitor 是否为越大越好; | |||||
:param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 ``True`` ,且没检测到设置 monitor 会报错; | |||||
""" | """ | ||||
def __init__(self, monitor, larger_better, must_have_monitor=False): | def __init__(self, monitor, larger_better, must_have_monitor=False): | ||||
super().__init__(monitor, larger_better) | super().__init__(monitor, larger_better) | ||||
@@ -230,20 +230,20 @@ class HasMonitorCallback(ResultsMonitor, Callback): | |||||
class ExecuteOnceBetterMonitor(HasMonitorCallback): | class ExecuteOnceBetterMonitor(HasMonitorCallback): | ||||
""" | """ | ||||
当监控的 monitor 结果更好的时候,调用 execute_fn 函数。 | |||||
当监控的 ``monitor`` 结果更好的时候,调用 ``execute_fn`` 函数。 | |||||
:param monitor: 监控的 metric 值。 | |||||
:param monitor: 监控的 metric 值: | |||||
* 为 ``None`` | * 为 ``None`` | ||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 ``monitor`` 值(如果有设置); | |||||
* 为 ``str`` | * 为 ``str`` | ||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | ||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` ; | |||||
* 为 :class:`Callable` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | ||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: monitor 是否时越大越好 | |||||
:param execute_fn: 一个可执行的函数,不接受任何参数,不反回值。在 monitor 取得更好结果的时候会调用。 | |||||
的 ``monitor`` 值请返回 ``None`` ; | |||||
:param larger_better: monitor 是否是越大越好; | |||||
:param execute_fn: 一个可执行的函数,不接受任何参数,没有返回值。在 monitor 取得更好结果的时候会调用; | |||||
""" | """ | ||||
def __init__(self, monitor, larger_better, execute_fn): | def __init__(self, monitor, larger_better, execute_fn): | ||||
super().__init__(monitor, larger_better, must_have_monitor=True) | super().__init__(monitor, larger_better, must_have_monitor=True) | ||||
@@ -19,25 +19,25 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型,默认会在加载之后删除权重文件。仅在训练正常结束的时候才能加载 | 保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型,默认会在加载之后删除权重文件。仅在训练正常结束的时候才能加载 | ||||
最好的模型。 | 最好的模型。 | ||||
:param monitor: 监控的 metric 值。 | |||||
:param monitor: 监控的 metric 值: | |||||
* 为 ``None`` | * 为 ``None`` | ||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置); | |||||
* 为 ``str`` | * 为 ``str`` | ||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | ||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` ; | |||||
* 为 :class:`Callable` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | ||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: 该 metric 值是否是越大越好。 | |||||
的 ``monitor`` 值请返回 ``None`` ; | |||||
:param larger_better: 该 metric 值是否是越大越好; | |||||
:param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | ||||
不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | |||||
:param only_state_dict: 是否只保存模型的参数。当 model_save_fn 不为空时,该值无效。 | |||||
:param model_save_fn: 保存 model 的函数,与 model_load_fn 必须同时不为空。本函数的输入为一个已经创建好的文件夹,没有输出, | |||||
请在函数内完成对模型的保存。 | |||||
:param model_load_fn: 加载 model 的函数,与 model_save_fn 必须同时不为空。本函数的输入为一个已经创建好的文件夹,没有输出, | |||||
请在函数内完成对模型的加载。 | |||||
:param delete_after_train: 在训练结束后是否删掉模型。 | |||||
不同的机器均可访问当该路径。当 ``model_save_fn`` 不为 None 时该值一定不能为空; | |||||
:param only_state_dict: 是否只保存模型的参数。当 ``model_save_fn`` 不为空时,该值无效; | |||||
:param model_save_fn: 保存 model 的函数,与 ``model_load_fn`` 必须同时不为空。本函数的输入为一个已经创建好的文件夹,没有输出, | |||||
请在函数内完成对模型的保存; | |||||
:param model_load_fn: 加载 model 的函数,与 ``model_save_fn`` 必须同时不为空。本函数的输入为一个已经创建好的文件夹,没有输出, | |||||
请在函数内完成对模型的加载; | |||||
:param delete_after_train: 在训练结束后是否删掉模型; | |||||
""" | """ | ||||
def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool = True, only_state_dict:bool = True, | def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool = True, only_state_dict:bool = True, | ||||
save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None, | save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None, | ||||
@@ -105,14 +105,16 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
def on_train_end(self, trainer): | def on_train_end(self, trainer): | ||||
if abs(self.monitor_value) != float('inf'): # 如果是 inf 说明从来没有运行过。 | if abs(self.monitor_value) != float('inf'): # 如果是 inf 说明从来没有运行过。 | ||||
if self.real_save_folder: | |||||
logger.info(f"Loading best model from {self.real_save_folder} with {self.monitor_name}: {self.monitor_value}...") | |||||
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | |||||
model_load_fn=self.model_load_fn) | |||||
else: | |||||
logger.info(f"Loading best model from buffer with {self.monitor_name}: {self.monitor_value}...") | |||||
self.buffer.seek(0) | |||||
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) | |||||
# 如果是分布式且报错了,就不要加载了,防止barrier的问题 | |||||
if not (trainer.driver.is_distributed() and self.encounter_exception): | |||||
if self.real_save_folder: | |||||
logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value}...") | |||||
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | |||||
model_load_fn=self.model_load_fn) | |||||
else: | |||||
logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value}...") | |||||
self.buffer.seek(0) | |||||
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) | |||||
if self.delete_after_after: | if self.delete_after_after: | ||||
if not self.encounter_exception: # 防止出现死锁。 | if not self.encounter_exception: # 防止出现死锁。 | ||||
trainer.driver.barrier() | trainer.driver.barrier() | ||||
@@ -7,11 +7,11 @@ __all__ = [ | |||||
class LRSchedCallback(Callback): | class LRSchedCallback(Callback): | ||||
""" | """ | ||||
根据 step_on 参数在合适的时机调用 scheduler 的 step 函数。 | |||||
根据 ``step_on`` 参数在合适的时机调用 scheduler 的 step 函数。 | |||||
:param scheduler: 实现了 step() 函数的对象 | |||||
:param step_on: 可选 ['batch', 'epoch'] 表示在何时调用 scheduler 的 step 函数。如果为 batch 的话在每次更新参数 | |||||
之前调用;如果为 epoch 则是在一个 epoch 运行结束后调用。 | |||||
:param scheduler: 实现了 :meth:`step` 函数的对象; | |||||
:param step_on: 可选 ``['batch', 'epoch']`` 表示在何时调用 scheduler 的 step 函数。如果为 ``batch`` 的话在每次更新参数 | |||||
之前调用;如果为 ``epoch`` 则是在一个 epoch 运行结束后调用; | |||||
""" | """ | ||||
def __init__(self, scheduler, step_on:str='batch'): | def __init__(self, scheduler, step_on:str='batch'): | ||||
assert hasattr(scheduler, 'step') and callable(scheduler.step), "The scheduler object should have a " \ | assert hasattr(scheduler, 'step') and callable(scheduler.step), "The scheduler object should have a " \ | ||||
@@ -19,7 +19,7 @@ class LRSchedCallback(Callback): | |||||
self.scheduler = scheduler | self.scheduler = scheduler | ||||
self.step_on = 0 if step_on == 'batch' else 1 | self.step_on = 0 if step_on == 'batch' else 1 | ||||
def on_before_optimizers_step(self, trainer, optimizers): | |||||
def on_after_optimizers_step(self, trainer, optimizers): | |||||
if self.step_on == 0: | if self.step_on == 0: | ||||
self.scheduler.step() | self.scheduler.step() | ||||
@@ -12,10 +12,10 @@ from .topk_saver import TopkSaver | |||||
class MoreEvaluateCallback(HasMonitorCallback): | class MoreEvaluateCallback(HasMonitorCallback): | ||||
""" | """ | ||||
当评测时需要调用不同的 evaluate_fn (例如在大部分生成任务中,一般使用训练 loss 作为训练过程中的 evaluate ;但同时在训练到 | |||||
一定 epoch 数量之后,会让 model 生成的完整的数据评测 bleu 等。此刻就可能需要两种不同的 evaluate_fn ),只使用 Trainer | |||||
无法满足需求,可以通过调用本 callback 进行。如果需要根据本 callback 中的评测结果进行模型保存,请传入 topk 以及 | |||||
topk_monitor 等相关参数。可以通过 evaluate_every 或 watch_monitor 控制触发进行 evaluate 的条件。 | |||||
当评测时需要调用不同的 ``evaluate_fn`` (例如在大部分生成任务中,一般使用训练 loss 作为训练过程中的 evaluate ;但同时在训练到 | |||||
一定 epoch 数量之后,会让 model 生成的完整的数据评测 bleu 等。此刻就可能需要两种不同的 ``evaluate_fn`` ),只使用 Trainer | |||||
无法满足需求,可以通过调用本 callback 进行。如果需要根据本 callback 中的评测结果进行模型保存,请传入 ``topk`` 以及 | |||||
``topk_monitor`` 等相关参数。可以通过 ``evaluate_every`` 或 ``watch_monitor`` 控制触发进行 evaluate 的条件。 | |||||
如果设置了 evaluate 结果更好就保存的话,将按如下文件结构进行保存:: | 如果设置了 evaluate 结果更好就保存的话,将按如下文件结构进行保存:: | ||||
@@ -30,7 +30,7 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||||
1. 为负数时表示每隔几个 ``epoch`` evaluate 一次; | 1. 为负数时表示每隔几个 ``epoch`` evaluate 一次; | ||||
2. 为正数则表示每隔几个 ``batch`` evaluate 一次; | 2. 为正数则表示每隔几个 ``batch`` evaluate 一次; | ||||
3. 为函数时表示用户自己传入的用于控制 evaluate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 | 3. 为函数时表示用户自己传入的用于控制 evaluate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 | ||||
返回一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 ``batch`` 结束后调用该函数判断是否需要 evaluate; | |||||
返回一个 bool 值,返回为 ``True`` 说明需要进行 evaluate ;将在每个 ``batch`` 结束后调用该函数判断是否需要 evaluate; | |||||
.. note:: | .. note:: | ||||
@@ -45,32 +45,41 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||||
该函数表示当每经过 1000 个 batch,``Trainer`` 中内置的 ``Evaluator`` 就会验证一次; | 该函数表示当每经过 1000 个 batch,``Trainer`` 中内置的 ``Evaluator`` 就会验证一次; | ||||
另一个需要注意的事情在于该函数会在每一次 batch 的结尾进行调用,当该函数返回 ``True`` 时,``Evaluator`` 才会进行验证; | 另一个需要注意的事情在于该函数会在每一次 batch 的结尾进行调用,当该函数返回 ``True`` 时,``Evaluator`` 才会进行验证; | ||||
:param watch_monitor: 这个值用来表示监控的 Trainer 中的 evaluate 结果的,当该值不为 None ,evaluate_every 失效。本参数的 | |||||
意义是,当检测到 Trainer 中 evaluate results 的 {watch_monitor} 的结果更好时,则进行一次 evaluate 。该参数有两种 | |||||
取值: (1) str 类型,监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最 | |||||
匹配的那个作为 monitor ; (2) 也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor | |||||
的结果,如果当前结果中没有相关的monitor 值请返回 None 。 | |||||
:param watch_monitor_larger_better: watch_monitor 是否越大越好。 | |||||
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 | |||||
`model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有 | |||||
找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。 | |||||
:param watch_monitor: 这个值用来表示监控的 Trainer 中的 evaluate 结果的,当该值不为 ``None`` ,``evaluate_every`` 失效。本参数的 | |||||
意义是,当检测到 Trainer 中 evaluate results 的 ``{watch_monitor}`` 的结果更好时,则进行一次 evaluate 。该参数有两种 | |||||
取值: | |||||
1. ``str`` 类型,含义为监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 **最长公共字符串算法** 找到最 | |||||
匹配的那个作为 monitor ; | |||||
2. 一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor | |||||
的结果,如果当前结果中没有相关的monitor 值请返回 ``None`` ; | |||||
:param watch_monitor_larger_better: ``watch_monitor`` 是否越大越好; | |||||
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 :meth:`model.evaluate_step` 还是 | |||||
:meth:`model.forward`: | |||||
1. 如果该值是 ``None``,那么我们会默认使用 :meth:`model.evaluate_step` 当做前向传播的函数,如果 | |||||
在模型中没有找到该方法,则使用 :meth:`model.forward` 函数; | |||||
2. 如果为 ``str`` 类型,则尝试从 model 中寻找该方法,找不到则报错; | |||||
:param num_eval_sanity_batch: 在初始化 Evaluator 后运行多少个 sanity check 的 batch ,检测一下。 | :param num_eval_sanity_batch: 在初始化 Evaluator 后运行多少个 sanity check 的 batch ,检测一下。 | ||||
:param topk: 如果需要根据当前 callback 中的 evaluate 结果保存模型或 Trainer ,可以通过设置 tokp 实现。(1)为 -1 表示每次 | |||||
evaluate 后都保存;(2)为 0 (默认),表示不保存;(3)为整数,表示保存性能最 topk 个。 | |||||
:param topk: 如果需要根据当前 callback 中的 evaluate 结果保存模型或 Trainer ,可以通过设置 topk 实现: | |||||
1. 为 ``-1`` 表示每次 evaluate 后都保存; | |||||
2. 为 ``0`` (默认),表示不保存; | |||||
3. 为整数,表示保存性能最好的 ``topk`` 个。 | |||||
:param topk_monitor: 如果需要根据当前 callback 中的 evaluate 结果保存。这个参数是指在当前 callback 中的 evaluate 结果寻找 | :param topk_monitor: 如果需要根据当前 callback 中的 evaluate 结果保存。这个参数是指在当前 callback 中的 evaluate 结果寻找 | ||||
:param topk_larger_better: topk_monitor 的值是否时越大越好。 | |||||
:param topk_larger_better: ``topk_monitor`` 的值是否是越大越好。 | |||||
:param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | :param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | ||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | |||||
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||||
保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||||
时间戳文件夹中。如果为 ``None`` ,默认使用当前文件夹。 | |||||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 ``model_save_fn`` 不为 ``None`` 时,该参数无效。 | |||||
:param save_object: 可选 ``['trainer', 'model']`` ,表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||||
保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.core.controllers.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||||
点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | 点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | ||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | ||||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 | |||||
``fastnlp_evaluate_results.json`` 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。 | |||||
:param save_kwargs: dict。更多的保存相关的参数。 | |||||
:param kwargs: 其它与 Evaluator 相关的初始化参数,如果不传入,将从 Trainer 中获取。 | |||||
如果传入了 ``model_save_fn`` 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 ``True`` ,在保存 topk 模型的 folder 中还将额外保存一个 | |||||
``fastnlp_evaluate_results.json`` 文件,记录当前的 results。仅在设置了 ``topk`` 的场景下有效,默认为 True 。 | |||||
:param save_kwargs: 一个字典,表示更多的保存相关的参数。 | |||||
:param kwargs: 其它与 :class:`~fastNLP.core.controllers.Evaluator` 相关的初始化参数,如果不传入,将从 :class:`~fastNLP.core.controllers.Trainer` 中获取。 | |||||
""" | """ | ||||
def __init__(self, dataloaders, metrics:Dict, evaluate_every:Optional[Union[int, Callable]]=-1, | def __init__(self, dataloaders, metrics:Dict, evaluate_every:Optional[Union[int, Callable]]=-1, | ||||
watch_monitor:Union[str, Callable]=None, watch_monitor_larger_better:bool=True, | watch_monitor:Union[str, Callable]=None, watch_monitor_larger_better:bool=True, | ||||
@@ -1,5 +1,4 @@ | |||||
import json | import json | ||||
import sys | |||||
from typing import Union | from typing import Union | ||||
__all__ = [ | __all__ = [ | ||||
@@ -16,8 +15,25 @@ from fastNLP.core.log import logger | |||||
class ProgressCallback(HasMonitorCallback): | class ProgressCallback(HasMonitorCallback): | ||||
def __init__(self, monitor, larger_better, must_have_monitor=False): | |||||
super(ProgressCallback, self).__init__(monitor=monitor, larger_better=larger_better, | |||||
must_have_monitor=must_have_monitor) | |||||
self.best_monitor_epoch = -1 | |||||
self.best_monitor_step = -1 | |||||
self.best_results = None | |||||
def record_better_monitor(self, trainer, results): | |||||
self.best_monitor_step = trainer.global_forward_batches | |||||
self.best_monitor_epoch = trainer.cur_epoch_idx | |||||
self.best_results = self.itemize_results(results) | |||||
def on_train_end(self, trainer): | def on_train_end(self, trainer): | ||||
f_rich_progress.stop() | |||||
if self.best_monitor_epoch != -1: | |||||
msg = f"The best performance for monitor {self._real_monitor}:{self.monitor_value} was achieved in" \ | |||||
f" Epoch:{self.best_monitor_epoch}, Global Batch:{self.best_monitor_step}." | |||||
if self.best_results is not None: | |||||
msg = msg + ' The evaluation result: \n' + str(self.best_results) | |||||
logger.info(msg) | |||||
@property | @property | ||||
def name(self): # progress bar的名称 | def name(self): # progress bar的名称 | ||||
@@ -44,21 +60,22 @@ def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> Prog | |||||
class RichCallback(ProgressCallback): | class RichCallback(ProgressCallback): | ||||
""" | """ | ||||
在训练过程中打印 rich progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的 | |||||
参数,请通过实例化本 Callback 并传入到 Trainer 中实现。 | |||||
在训练过程中打印 *rich* progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的 | |||||
参数,请通过实例化本 Callback 并传入到 Trainer 中实现。在打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。 | |||||
:param print_every: 多少个 batch 更新一次显示。 | :param print_every: 多少个 batch 更新一次显示。 | ||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
:param monitor: 监控的 metric 值。当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | :param monitor: 监控的 metric 值。当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | ||||
* 为 ``None`` | * 为 ``None`` | ||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | * 为 ``str`` | ||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | ||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | ||||
* 为 ``Callable`` | |||||
* 为 :class:`Callable` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | ||||
的 ``monitor`` 值请返回 ``None`` 。 | 的 ``monitor`` 值请返回 ``None`` 。 | ||||
:param larger_better: 是否是 monitor 的结果越大越好。 | :param larger_better: 是否是 monitor 的结果越大越好。 | ||||
:param format_json: 是否格式化 json 再打印 | :param format_json: 是否格式化 json 再打印 | ||||
""" | """ | ||||
@@ -97,6 +114,7 @@ class RichCallback(ProgressCallback): | |||||
advance=None, completed=trainer.cur_epoch_idx, refresh=True) | advance=None, completed=trainer.cur_epoch_idx, refresh=True) | ||||
def on_train_end(self, trainer): | def on_train_end(self, trainer): | ||||
super(RichCallback, self).on_train_end(trainer) | |||||
self.clear_tasks() | self.clear_tasks() | ||||
def on_before_backward(self, trainer, outputs): | def on_before_backward(self, trainer, outputs): | ||||
@@ -121,8 +139,8 @@ class RichCallback(ProgressCallback): | |||||
text_style = '' | text_style = '' | ||||
characters = '-' | characters = '-' | ||||
if self.monitor is not None: | if self.monitor is not None: | ||||
monitor_value = self.get_monitor_value(results) | |||||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||||
if self.is_better_results(results, keep_if_better=True): | |||||
self.record_better_monitor(trainer, results) | |||||
if abs(self.monitor_value) != float('inf'): | if abs(self.monitor_value) != float('inf'): | ||||
rule_style = 'spring_green3' | rule_style = 'spring_green3' | ||||
text_style = '[bold]' | text_style = '[bold]' | ||||
@@ -131,8 +149,11 @@ class RichCallback(ProgressCallback): | |||||
self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, " | self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, " | ||||
f"Batch:{trainer.batch_idx_in_epoch}", | f"Batch:{trainer.batch_idx_in_epoch}", | ||||
style=rule_style, characters=characters) | style=rule_style, characters=characters) | ||||
results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if | |||||
not key.startswith('_')} | |||||
if self.format_json: | if self.format_json: | ||||
self.progress_bar.console.print_json(json.dumps(trainer.driver.tensor_to_numeric(results))) | |||||
results = json.dumps(results) | |||||
self.progress_bar.console.print_json(results) | |||||
else: | else: | ||||
self.progress_bar.print(results) | self.progress_bar.print(results) | ||||
@@ -149,26 +170,26 @@ class RichCallback(ProgressCallback): | |||||
class RawTextCallback(ProgressCallback): | class RawTextCallback(ProgressCallback): | ||||
""" | |||||
通过向命令行打印进度的方式显示。在打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。 | |||||
:param print_every: 多少个 batch 更新一次显示。 | |||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||||
:param monitor: 监控的 metric 值。当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 :class:`Callable` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: 是否是monitor的结果越大越好。 | |||||
:param format_json: 是否format json再打印 | |||||
""" | |||||
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | ||||
format_json=True): | format_json=True): | ||||
""" | |||||
通过向命令行打印进度的方式显示 | |||||
:param print_every: 多少个 batch 更新一次显示。 | |||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||||
:param monitor: 监控的 metric 值。当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: 是否是monitor的结果越大越好。 | |||||
:param format_json: 是否format json再打印 | |||||
""" | |||||
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) | super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) | ||||
self.print_every = print_every | self.print_every = print_every | ||||
self.task2id = {} | self.task2id = {} | ||||
@@ -201,18 +222,19 @@ class RawTextCallback(ProgressCallback): | |||||
base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | ||||
text = '' | text = '' | ||||
if self.monitor is not None: | if self.monitor is not None: | ||||
monitor_value = self.get_monitor_value(results) | |||||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||||
if self.is_better_results(results, keep_if_better=True): | |||||
self.record_better_monitor(trainer, results) | |||||
if abs(self.monitor_value) != float('inf'): | if abs(self.monitor_value) != float('inf'): | ||||
text = '+'*self.num_signs + base_text + '+'*self.num_signs | text = '+'*self.num_signs + base_text + '+'*self.num_signs | ||||
if len(text) == 0: | if len(text) == 0: | ||||
text = '-'*self.num_signs + base_text + '-'*self.num_signs | text = '-'*self.num_signs + base_text + '-'*self.num_signs | ||||
logger.info(text) | logger.info(text) | ||||
results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if | |||||
not key.startswith('_')} | |||||
if self.format_json: | if self.format_json: | ||||
logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) | |||||
else: | |||||
logger.info(results) | |||||
results = json.dumps(results) | |||||
logger.info(results) | |||||
@property | @property | ||||
def name(self): # progress bar的名称 | def name(self): # progress bar的名称 | ||||
@@ -221,19 +243,20 @@ class RawTextCallback(ProgressCallback): | |||||
class TqdmCallback(ProgressCallback): | class TqdmCallback(ProgressCallback): | ||||
""" | """ | ||||
在训练过程中打印 tqdm progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的 | |||||
参数,请通过实例化本 Callback 并传入到 Trainer 中实现。 | |||||
在训练过程中打印 *tqdm* progress bar 的 callback 。在 Trainer 中,如果设置了 ``progress_bar='tqdm'`` 就会使用 | |||||
这个 callback 来显示进度。如果需要定制这个 Callback 的参数,请通过实例化本 Callback 并传入到 Trainer 中实现。在 | |||||
打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。 | |||||
:param print_every: 多少个 batch 更新一次显示。 | :param print_every: 多少个 batch 更新一次显示。 | ||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
:param monitor: 监控的 metric 值。当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | :param monitor: 监控的 metric 值。当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | ||||
* 为 ``None`` | * 为 ``None`` | ||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | * 为 ``str`` | ||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | ||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | ||||
* 为 ``Callable`` | |||||
* 为 :class:`Callable` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | ||||
的 ``monitor`` 值请返回 ``None`` 。 | 的 ``monitor`` 值请返回 ``None`` 。 | ||||
:param larger_better: 是否是 monitor 的结果越大越好。 | :param larger_better: 是否是 monitor 的结果越大越好。 | ||||
@@ -266,6 +289,7 @@ class TqdmCallback(ProgressCallback): | |||||
self.progress_bar.set_description_str(self.task2id['epoch'], f'Epoch:{trainer.cur_epoch_idx}', refresh=True) | self.progress_bar.set_description_str(self.task2id['epoch'], f'Epoch:{trainer.cur_epoch_idx}', refresh=True) | ||||
def on_train_end(self, trainer): | def on_train_end(self, trainer): | ||||
super(TqdmCallback, self).on_train_end(trainer) | |||||
self.clear_tasks() | self.clear_tasks() | ||||
def on_before_backward(self, trainer, outputs): | def on_before_backward(self, trainer, outputs): | ||||
@@ -287,18 +311,19 @@ class TqdmCallback(ProgressCallback): | |||||
base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | ||||
text = '' | text = '' | ||||
if self.monitor is not None: | if self.monitor is not None: | ||||
monitor_value = self.get_monitor_value(results) | |||||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||||
if self.is_better_results(results, keep_if_better=True): | |||||
self.record_better_monitor(trainer, results) | |||||
if abs(self.monitor_value) != float('inf'): | if abs(self.monitor_value) != float('inf'): | ||||
text = '+'*self.num_signs + base_text + '+'*self.num_signs | text = '+'*self.num_signs + base_text + '+'*self.num_signs | ||||
if len(text) == 0: | if len(text) == 0: | ||||
text = '-'*self.num_signs + base_text + '-'*self.num_signs | text = '-'*self.num_signs + base_text + '-'*self.num_signs | ||||
logger.info(text) | logger.info(text) | ||||
results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if | |||||
not key.startswith('_')} | |||||
if self.format_json: | if self.format_json: | ||||
logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) | |||||
else: | |||||
logger.info(results) | |||||
results = json.dumps(results) | |||||
logger.info(results) | |||||
def clear_tasks(self): | def clear_tasks(self): | ||||
for key, taskid in self.task2id.items(): | for key, taskid in self.task2id.items(): | ||||
@@ -0,0 +1,152 @@ | |||||
import time | |||||
from .callback import Callback | |||||
from ..log import logger | |||||
__all__ = ['TimerCallback'] | |||||
class _Timer: | |||||
"""Timer.""" | |||||
def __init__(self, name): | |||||
self.name_ = name | |||||
self.elapsed_ = 0.0 | |||||
self.started_ = False | |||||
self.start_time = time.time() | |||||
def start(self): | |||||
"""Start the timer.""" | |||||
assert not self.started_, f'{self.name_} timer has already been started' | |||||
self.start_time = time.time() | |||||
self.started_ = True | |||||
def stop(self): | |||||
"""Stop the timer.""" | |||||
assert self.started_, f'{self.name_} timer is not started' | |||||
self.elapsed_ += (time.time() - self.start_time) | |||||
self.started_ = False | |||||
def reset(self): | |||||
"""Reset timer.""" | |||||
self.elapsed_ = 0.0 | |||||
self.started_ = False | |||||
def elapsed(self, reset=True): | |||||
"""Calculate the elapsed time.""" | |||||
started_ = self.started_ | |||||
# If the timing in progress, end it first. | |||||
if self.started_: | |||||
self.stop() | |||||
# Get the elapsed time. | |||||
elapsed_ = self.elapsed_ | |||||
# Reset the elapsed time | |||||
if reset: | |||||
self.reset() | |||||
# If timing was in progress, set it back. | |||||
if started_: | |||||
self.start() | |||||
return elapsed_ | |||||
class Timers: | |||||
"""Group of timers.""" | |||||
def __init__(self): | |||||
self.timers = {} | |||||
def __call__(self, name): | |||||
if name not in self.timers: | |||||
self.timers[name] = _Timer(name) | |||||
return self.timers[name] | |||||
def __contains__(self, item): | |||||
return item in self.timers | |||||
def reset(self): | |||||
for timer in self.timers.values(): | |||||
timer.reset() | |||||
class TimerCallback(Callback): | |||||
""" | |||||
这个 callback 的作用是打印训练过程中的相关时间信息,例如训练时长、评测时长、总时长等 | |||||
""" | |||||
def __init__(self, print_every=-1, time_ndigit=3): | |||||
""" | |||||
:param print_every: 在哪个时候打印时间信息。 | |||||
* *负数*: 表示每隔多少 epoch 结束打印一次; | |||||
* *0*: 表示整个训练结束才打印; | |||||
* *正数*: 每隔多少个 step 打印一次; | |||||
:param time_ndigit: 保留多少位的小数 | |||||
""" | |||||
assert isinstance(print_every, int), "print_every must be an int number." | |||||
self.timers = Timers() | |||||
self.print_every = print_every | |||||
self.time_ndigit = time_ndigit | |||||
def on_train_begin(self, trainer): | |||||
self.timers('total').start() | |||||
self.timers('train').start() | |||||
def on_fetch_data_begin(self, trainer): | |||||
self.timers('fetch-data').start() | |||||
def on_fetch_data_end(self, trainer): | |||||
self.timers('fetch-data').stop() | |||||
def on_train_batch_begin(self, trainer, batch, indices): | |||||
self.timers('forward').start() | |||||
def on_before_backward(self, trainer, outputs): | |||||
self.timers('forward').stop() | |||||
self.timers('backward').start() | |||||
def on_after_backward(self, trainer): | |||||
self.timers('backward').stop() | |||||
def on_before_optimizers_step(self, trainer, optimizers): | |||||
self.timers('optimize').start() | |||||
def on_after_optimizers_step(self, trainer, optimizers): | |||||
self.timers('optimize').stop() | |||||
def on_evaluate_begin(self, trainer): | |||||
self.timers('train').stop() | |||||
self.timers('evaluate').start() | |||||
def on_evaluate_end(self, trainer, results): | |||||
self.timers('evaluate').stop() | |||||
self.timers('train').start() | |||||
def format_timer(self, reset=True): | |||||
line = '' | |||||
timers = ['fetch-data', 'forward', 'backward', 'optimize', 'evaluate', 'train', 'total'] | |||||
for timer_name in timers: | |||||
if not timer_name in self.timers: | |||||
continue | |||||
timer = self.timers(timer_name) | |||||
elapsed = round(timer.elapsed(reset=reset), self.time_ndigit) | |||||
if elapsed != 0: | |||||
line = line + f', {timer_name}: {elapsed}s' | |||||
return line | |||||
def on_train_batch_end(self, trainer): | |||||
if self.print_every>0 and trainer.global_forward_batches % self.print_every == 0: | |||||
line = self.format_timer() | |||||
logger.info(f"Running {self.print_every} batches{line}") | |||||
def on_train_epoch_end(self, trainer): | |||||
if self.print_every < 0 and trainer.cur_epoch_idx % abs(self.print_every) == 0: | |||||
line = self.format_timer() | |||||
logger.info(f"Running {abs(self.print_every)} epochs{line}") | |||||
def on_train_end(self, trainer): | |||||
if self.print_every == 0: | |||||
line = self.format_timer() | |||||
logger.info(f"Training finished{line}") | |||||
@@ -24,8 +24,8 @@ class Saver: | |||||
- folder_name # 由 save() 调用时传入。 | - folder_name # 由 save() 调用时传入。 | ||||
:param folder: 保存在哪个文件夹下,默认为当前 folder 下。 | :param folder: 保存在哪个文件夹下,默认为当前 folder 下。 | ||||
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||||
保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||||
:param save_object: 可选 ``['trainer', 'model']`` ,表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||||
保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.core.controllers.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||||
点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | 点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | ||||
:param only_state_dict: 保存时是否仅保存权重,在 model_save_fn 不为 None 时无意义。 | :param only_state_dict: 保存时是否仅保存权重,在 model_save_fn 不为 None 时无意义。 | ||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | ||||
@@ -178,28 +178,28 @@ class TopkSaver(ResultsMonitor, Saver): | |||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | ||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名 | - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名 | ||||
:param topk: 保存 topk 多少的模型,-1 为保存所有模型;0 为都不保存;大于 0 的数为保存 topk 个。 | |||||
:param monitor: 监控的 metric 值。 | |||||
:param topk: 保存表现最好的 ``topk`` 个模型,-1 为保存所有模型;0 为都不保存;大于 0 的数为保存 ``topk`` 个; | |||||
:param monitor: 监控的 metric 值: | |||||
* 为 ``None`` | * 为 ``None`` | ||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | * 为 ``str`` | ||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | ||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | ||||
* 为 ``Callable`` | |||||
* 为 :class:`Callable` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | ||||
的 ``monitor`` 值请返回 ``None`` 。 | 的 ``monitor`` 值请返回 ``None`` 。 | ||||
:param larger_better: 该 monitor 是否越大越好。 | :param larger_better: 该 monitor 是否越大越好。 | ||||
:param folder: 保存在哪个文件夹下,默认为当前 folder 下。 | :param folder: 保存在哪个文件夹下,默认为当前 folder 下。 | ||||
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||||
保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||||
:param save_object: 可选 ``['trainer', 'model']`` ,表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||||
保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.core.controllers.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||||
点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | 点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | ||||
:param only_state_dict: 保存时是否仅保存权重,在 model_save_fn 不为 None 时无意义。 | |||||
:param only_state_dict: 保存时是否仅保存权重,在 ``model_save_fn`` 不为 None 时无意义。 | |||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | ||||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
如果传入了 ``model_save_fn`` 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 | :param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 | ||||
``fastnlp_evaluate_results.json`` 文件,记录当前的 metric results 。仅在设置了 topk 的场景下有用,默认为 True 。 | |||||
:param kwargs: 更多需要传递给 Trainer.save_checkpoint() 或者 Trainer.save_model() 接口的参数。 | |||||
``fastnlp_evaluate_results.json`` 文件,记录当前的 metric results 。仅在设置了 ``topk`` 的场景下有用,默认为 True 。 | |||||
:param kwargs: 更多需要传递给 :meth:`Trainer.save_checkpoint` 或者 :meth:`Trainer.save_model` 接口的参数。 | |||||
""" | """ | ||||
def __init__(self, topk:int=0, monitor:str=None, larger_better:bool=True, folder:str=None, save_object:str='model', | def __init__(self, topk:int=0, monitor:str=None, larger_better:bool=True, folder:str=None, save_object:str='model', | ||||
only_state_dict:bool=True, model_save_fn:Callable=None, save_evaluate_results:bool=True, | only_state_dict:bool=True, model_save_fn:Callable=None, save_evaluate_results:bool=True, | ||||
@@ -220,7 +220,7 @@ class TopkSaver(ResultsMonitor, Saver): | |||||
@rank_zero_call | @rank_zero_call | ||||
def save_topk(self, trainer, results: Dict) -> Optional[str]: | def save_topk(self, trainer, results: Dict) -> Optional[str]: | ||||
""" | """ | ||||
根据 results 是否满足 topk 的相关设定决定是否保存,如果发生了保存,将返回保存的文件夹。如果返回为 None ,则说明此次没有满足 | |||||
根据 ``results`` 是否满足 topk 的相关设定决定是否保存,如果发生了保存,将返回保存的文件夹。如果返回为 ``None`` ,则说明此次没有满足 | |||||
topk 要求,没有发生保存。 | topk 要求,没有发生保存。 | ||||
:param trainer: | :param trainer: | ||||
@@ -11,17 +11,17 @@ if _NEED_IMPORT_FAIRSCALE: | |||||
class TorchGradClipCallback(Callback): | class TorchGradClipCallback(Callback): | ||||
r""" | r""" | ||||
在每次 optimizer update 之前将 parameter 进行 clip 。 | |||||
在每次 :func:`optimizer.step` 之前对参数的梯度进行截断。 | |||||
:param clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数 | |||||
:param clip_type: 支持'norm', 'value'两种: | |||||
:param clip_value: 将梯度限制到 [-clip_value, clip_value] 之间。``clip_value`` 应该为正数; | |||||
:param clip_type: 应为 ``'norm'``, ``'value'`` 中的一个: | |||||
1. 'norm', 将gradient的norm rescale到[-clip_value, clip_value] | |||||
2. 'value', 将gradient限制在[-clip_value, clip_value], | |||||
小于-clip_value的gradient被赋值为-clip_value;大于clip_value的gradient被赋值为clip_value. | |||||
1. 为 ``'norm'`` 时, 将梯度的范数限制在 [-clip_value, clip_value] 之间; | |||||
2. 为 ``'value'`` 时,, 将梯度限制在 [-clip_value, clip_value] 之间,小于 ``-clip_value`` | |||||
的梯度被赋值为 ``-clip_value``,大于 ``clip_value`` 的梯度被赋值为 ``clip_value``; | |||||
:param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。 | |||||
如果为None则默认对 Trainer 的 optimizers 中所有参数进行梯度裁剪。 | |||||
:param parameters: 参数,一般通过 :func:`model.parameters` 获得。 | |||||
如果为 ``None`` 则默认对 Trainer 的 optimizers 中所有参数进行梯度裁剪。 | |||||
""" | """ | ||||
def __init__(self, clip_value:int=1, clip_type:str='norm', | def __init__(self, clip_value:int=1, clip_type:str='norm', | ||||
parameters:Union["torch.Tensor", List["torch.Tensor"]]=None): | parameters:Union["torch.Tensor", List["torch.Tensor"]]=None): | ||||
@@ -9,14 +9,14 @@ from ..callback import Callback | |||||
class TorchWarmupCallback(Callback): | class TorchWarmupCallback(Callback): | ||||
r""" | r""" | ||||
调整 learning rate 的 callback 。 | |||||
调整学习率的 **callback** 。 | |||||
:param warmup: 如果warmup为int,则在该step之前,learning rate根据schedule的策略变化; 如果warmup为float, | |||||
如0.1, 则前10%的step是按照schedule策略调整learning rate。 | |||||
:param schedule: 以哪种方式调整。 | |||||
:param warmup: 如果 ``warmup`` 为整数,则在该 step 之前,学习率根据 ``schedule`` 的策略变化; 如果 ``warmup`` 为 ``float``, | |||||
如 0.1, 则前 10% 的 step 是按照 ``schedule`` 策略调整。 | |||||
:param schedule: 对学习率进行调整的策略: | |||||
1. linear: 前warmup的step上升到指定的learning rate(从Trainer中的optimizer处获取的), 后warmup的step下降到0; | |||||
2. constant前warmup的step上升到指定learning rate,后面的step保持learning rate. | |||||
1. *linear* -- 前 ``warmup`` 的 step 上升到指定的学习率(从 Trainer 中 optimizer 处获取), 在剩下的 step 中下降到 0; | |||||
2. *constant* -- 前 ``warmup`` 的 step 上升到指定的学习率,余下的 step 保持不变。 | |||||
""" | """ | ||||
def __init__(self, warmup:Union[int, float]=0.1, schedule:str='constant'): | def __init__(self, warmup:Union[int, float]=0.1, schedule:str='constant'): | ||||
super().__init__() | super().__init__() | ||||
@@ -18,7 +18,7 @@ from .packer_unpacker import SequencePackerUnpacker, SinglePackerUnpacker, Mappi | |||||
NestedMappingPackerUnpacker | NestedMappingPackerUnpacker | ||||
sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 | sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 | ||||
SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None] | |||||
SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'oneflow', 'numpy', 'raw', 'auto', None] | |||||
# 由于 jittor DataLoader 存在自动的 to_jittor 的转换,所以只需要 collate 成为 numpy 就行 | # 由于 jittor DataLoader 存在自动的 to_jittor 的转换,所以只需要 collate 成为 numpy 就行 | ||||
AUTO_BACKEND_MAPPING = {'jittor': 'numpy'} | AUTO_BACKEND_MAPPING = {'jittor': 'numpy'} | ||||
@@ -85,27 +85,33 @@ def _get_backend() -> str: | |||||
class Collator: | class Collator: | ||||
""" | """ | ||||
用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 | 用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 | ||||
哦安定一个 field 是否可以 pad 的方式为:(1)当前这个 field 是否所有对象都是一样的数据类型;(因此,如果某 field 的数据有些是float | |||||
有些是 int 将知道该 field 被判定为不可 pad 类型。)(2)当前这个 field 是否每个 sample 都具有一样的深度;(因此,例如有个 field 的 | |||||
数据转为 batch 类型后为 [1, [1,2]], 会被判定为不可 pad ,因为第一个 sample 与 第二个 sample 深度不同)(3)当前这个 field 的类 | |||||
型是否是可以 pad (例如 str 类型的数据)。可以通过设置 logger.setLevel('debug') 来打印是判定不可 pad 的原因。 | |||||
判定一个 field 是否可以 pad 的方式为: | |||||
1. 当前这个 field 是否所有对象都是一样的数据类型;比如,如果某 field 的数据有些是 float ,有些是 int ,则该 field 将被 | |||||
判定为不可 pad 类型; | |||||
2. 当前这个 field 是否每个 sample 都具有一样的深度;比如,如果某 field 的数据转为 batch 类型后为 ``[1, [1,2]]``, 则会 | |||||
被判定为不可 pad ,因为第一个 sample 与 第二个 sample 深度不同; | |||||
3. 当前这个 field 的类型是否是可以 pad (例如 str 类型的数据)。可以通过设置 ``logger.setLevel('debug')`` 来打印是判定不可 | |||||
pad 的原因。 | |||||
.. note:: | .. note:: | ||||
``Collator`` 的原理是使用第一个 ``batch`` 的数据尝试推断每个``field``应该使用哪种类型的 ``Padder``,如果第一个 ``batch`` | |||||
的数据刚好比较特殊,可能导致在之后的 pad 中遭遇失败,这种情况请通过 ``set_pad()`` 函数手动设置一下。 | |||||
``Collator`` 的原理是使用第一个 ``batch`` 的数据尝试推断每个 ``field`` 应该使用哪种类型的 ``Padder``,如果第一个 ``batch`` | |||||
的数据刚好比较特殊,可能导致在之后的 pad 中遭遇失败,这种情况请通过 :meth:`set_pad` 函数手动设置一下。 | |||||
todo 补充 code example 。 | |||||
.. todo:: | |||||
补充 code example 。 | |||||
如果需要将某个本可以 pad 的 field 设置为不可 pad ,则可以通过 :meth:`~fastNLP.Collator.set_pad` 的 pad_val 设置为 None 实现。 | |||||
如果需要将某个本可以 pad 的 field 设置为不可 pad ,则可以通过 :meth:`~fastNLP.Collator.set_pad` 的 ``pad_val`` 设置为 ``None`` 实现。 | |||||
如果需要某些 field 不要包含在 pad 之后的结果中,可以使用 :meth:`~fastNLP.Collator.set_ignore` 进行设置。 | 如果需要某些 field 不要包含在 pad 之后的结果中,可以使用 :meth:`~fastNLP.Collator.set_ignore` 进行设置。 | ||||
Collator 在第一次进行 pad 的时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应 | Collator 在第一次进行 pad 的时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应 | ||||
的 Padder 给对应的 field 。 | 的 Padder 给对应的 field 。 | ||||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', auto, None]。 | |||||
若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对不能进行 pad 的数据没用影响,不能 pad | |||||
的数据返回一定是 list 。 | |||||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ``['torch','jittor','paddle','oneflow','numpy','raw', 'auto', None]``。 | |||||
若为 ``'auto'`` ,则在进行 pad 的时候会根据调用的环境决定其 ``backend`` 。该参数对不能进行 pad 的数据没有影响,无法 pad 的数据返回一定 | |||||
是 :class:`list` 。 | |||||
""" | """ | ||||
def __init__(self, backend='auto'): | def __init__(self, backend='auto'): | ||||
self.unpack_batch_func = None | self.unpack_batch_func = None | ||||
@@ -192,20 +198,20 @@ class Collator: | |||||
""" | """ | ||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | ||||
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); | |||||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | |||||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | |||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 | |||||
无意义。 | |||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | |||||
:param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, | |||||
torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 | |||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | |||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | |||||
形式,输出将被直接作为结果输出。 | |||||
:return: 返回 Collator 自身 | |||||
:param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||||
如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 "_single" 。 | |||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``, | |||||
该值无意义。 | |||||
:param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。 | |||||
:param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`, | |||||
:class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。 | |||||
若 ``pad_val`` 为 ``None`` ,该值无意义 。 | |||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的 | |||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。 | |||||
:return: 返回 Collator 自身; | |||||
""" | """ | ||||
self._renew() | self._renew() | ||||
@@ -275,8 +281,8 @@ class Collator: | |||||
""" | """ | ||||
设置可以 pad 的 field 默认 pad 为什么类型的 tensor | 设置可以 pad 的 field 默认 pad 为什么类型的 tensor | ||||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', 'auto', None], | |||||
若为 auto ,则在进行 pad 的时候会自动根据调用的环境决定其 backend 。 | |||||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ``['torch','jittor','paddle','oneflow','numpy','raw', 'auto', None]``, | |||||
若为 ``'auto'`` ,则在进行 pad 的时候会自动根据调用的环境决定其 ``backend`` ; | |||||
:return: | :return: | ||||
""" | """ | ||||
assert backend in SUPPORTED_BACKENDS | assert backend in SUPPORTED_BACKENDS | ||||
@@ -285,14 +291,14 @@ class Collator: | |||||
def set_ignore(self, *field_names) -> "Collator": | def set_ignore(self, *field_names) -> "Collator": | ||||
""" | """ | ||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | |||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略:: | |||||
>>> collator = Collator().set_ignore('field1', 'field2') | >>> collator = Collator().set_ignore('field1', 'field2') | ||||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 | |||||
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | |||||
:return: 返回 Collator 自身 | |||||
:param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||||
:return: 返回 Collator 自身; | |||||
""" | """ | ||||
self._renew() | self._renew() | ||||
input_field_names = [(field, field) if isinstance(field, tuple) else ((field,), field) | input_field_names = [(field, field) if isinstance(field, tuple) else ((field,), field) | ||||
@@ -2,6 +2,7 @@ from collections import defaultdict | |||||
from functools import reduce | from functools import reduce | ||||
from typing import Sequence, Mapping, Dict | from typing import Sequence, Mapping, Dict | ||||
__all__ = [] | |||||
class MappingPackerUnpacker: | class MappingPackerUnpacker: | ||||
@staticmethod | @staticmethod | ||||
@@ -70,7 +71,7 @@ class SequencePackerUnpacker: | |||||
@staticmethod | @staticmethod | ||||
def unpack_batch(batch:Sequence[Sequence], ignore_fields, input_fields)->Dict: | def unpack_batch(batch:Sequence[Sequence], ignore_fields, input_fields)->Dict: | ||||
""" | """ | ||||
将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [1, 2]} | |||||
将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [2, 2]} | |||||
:param batch: 需要 unpack 的 batch 数据。 | :param batch: 需要 unpack 的 batch 数据。 | ||||
:param ignore_fields: 需要忽略的 field 。 | :param ignore_fields: 需要忽略的 field 。 | ||||
@@ -10,18 +10,19 @@ from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPad | |||||
from .raw_padder import RawNumberPadder, RawSequencePadder, RawTensorPadder | from .raw_padder import RawNumberPadder, RawSequencePadder, RawTensorPadder | ||||
from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder | from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder | ||||
from .jittor_padder import JittorTensorPadder, JittorSequencePadder, JittorNumberPadder | from .jittor_padder import JittorTensorPadder, JittorSequencePadder, JittorNumberPadder | ||||
from .oneflow_padder import OneflowTensorPadder, OneflowSequencePadder, OneflowNumberPadder | |||||
from .exceptions import * | from .exceptions import * | ||||
def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->Padder: | def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->Padder: | ||||
""" | """ | ||||
根据 参数 与 batch_field ,返回适合于当前 batch_field 的 padder 。 | |||||
根据 参数 与 ``batch_field`` ,返回适合于当前 ``batch_field`` 的 *padder* 。 | |||||
:param batch_field: 将某 field 的内容组合成一个 batch 传入。 | |||||
:param pad_val: | |||||
:param batch_field: 将某 field 的内容组合成一个 batch 传入; | |||||
:param pad_val: | |||||
:param backend: | :param backend: | ||||
:param dtype: | :param dtype: | ||||
:param field_name: 方便报错的。 | |||||
:param field_name: field 名称,方便在报错时显示; | |||||
:return: | :return: | ||||
""" | """ | ||||
try: | try: | ||||
@@ -91,6 +92,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> | |||||
return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | ||||
elif backend == 'jittor': | elif backend == 'jittor': | ||||
return JittorNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | return JittorNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | ||||
elif backend == 'oneflow': | |||||
return OneflowNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
else: | else: | ||||
raise ValueError(f"backend={backend} is not supported for list(Field:{field_name}).") | raise ValueError(f"backend={backend} is not supported for list(Field:{field_name}).") | ||||
@@ -105,6 +108,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> | |||||
return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | ||||
elif backend == 'jittor': | elif backend == 'jittor': | ||||
return JittorSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | return JittorSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | ||||
elif backend == 'oneflow': | |||||
return OneflowSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
else: | else: | ||||
raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).") | raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).") | ||||
@@ -121,6 +126,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> | |||||
return PaddleTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | return PaddleTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | ||||
elif backend == 'jittor': | elif backend == 'jittor': | ||||
return JittorTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | return JittorTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | ||||
elif backend == 'oneflow': | |||||
return OneflowTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
else: | else: | ||||
raise ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).") | raise ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).") | ||||
@@ -84,14 +84,14 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||||
class JittorNumberPadder(Padder): | class JittorNumberPadder(Padder): | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
可以将形如 [1, 2, 3] 这类的数据转为 jittor.Var([1, 2, 3]) | |||||
""" | |||||
可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``jittor.Var([1, 2, 3])`` | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 jittor.Var 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 jittor.long, jittor.float32, int, float 等 | |||||
""" | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`jittor.Var` 类型; | |||||
:param dtype: 输出的数据的 dtype 是什么。如 :class:`jittor.long`, :class:`jittor.float32`, :class:`int`, :class:`float` 等; | |||||
""" | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@@ -106,23 +106,23 @@ class JittorNumberPadder(Padder): | |||||
class JittorSequencePadder(Padder): | class JittorSequencePadder(Padder): | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
将类似于 [[1], [1, 2]] 的内容 pad 为 jittor.Var([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||||
""" | |||||
可以将形如 ``[[1], [1, 2]]`` 这类的数据转为 ``jittor.Var([[1], [1, 2]])`` | |||||
:param pad_val: 需要 pad 的值。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 jittor.Var 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 jittor.long, jittor.float32, int, float 等 | |||||
""" | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`jittor.Var` 类型; | |||||
:param dtype: 输出的数据的 dtype 是什么。如 :class:`jittor.long`, :class:`jittor.float32`, :class:`int`, :class:`float` 等; | |||||
""" | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val=0, dtype=None): | def pad(batch_field, pad_val=0, dtype=None): | ||||
""" | """ | ||||
:param batch_field 输入的某个 field 的 batch 数据。 | |||||
:param pad_val 需要填充的值 | |||||
:dtype 数据的类型 | |||||
:param batch_field: 输入的某个 field 的 batch 数据。 | |||||
:param pad_val: 需要填充的值 | |||||
:param dtype: 数据的类型 | |||||
""" | """ | ||||
tensor = get_padded_jittor_tensor(batch_field, dtype=dtype, pad_val=pad_val) | tensor = get_padded_jittor_tensor(batch_field, dtype=dtype, pad_val=pad_val) | ||||
return tensor | return tensor | ||||
@@ -131,11 +131,11 @@ class JittorSequencePadder(Padder): | |||||
class JittorTensorPadder(Padder): | class JittorTensorPadder(Padder): | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | ||||
""" | """ | ||||
目前支持 [jittor.Var([3, 2], jittor.Var([1])] 类似的。若内部元素不为 jittor.Var ,则必须含有 tolist() 方法。 | |||||
目前支持 ``[jittor.Var([3, 2], jittor.Var([1])]`` 类似的输入。若内部元素不为 :class:`jittor.Var` ,则必须含有 :meth:`tolist` 方法。 | |||||
:param pad_val: 需要 pad 的值。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 jittor.Var 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 jittor.long, jittor.float32, int, float 等 | |||||
:param pad_val: 需要 pad 的值; | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`jittor.Var` 类型; | |||||
:param dtype: 输出的数据的 dtype 是什么。如 :class:`jittor.long`, :class:`jittor.float32`, :class:`int`, :class:`float` 等 | |||||
""" | """ | ||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@@ -143,11 +143,11 @@ class JittorTensorPadder(Padder): | |||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val=0, dtype=None): | def pad(batch_field, pad_val=0, dtype=None): | ||||
""" | """ | ||||
将 batch_field 数据 转为 jittor.Var 并 pad 到相同长度。 | |||||
将 ``batch_field`` 数据 转为 :class:`jittor.Var` 并 pad 到相同长度。 | |||||
:param batch_field 输入的某个 field 的 batch 数据。 | |||||
:param pad_val 需要填充的值 | |||||
:dtype 数据的类型 | |||||
:param batch_field: 输入的某个 field 的 batch 数据。 | |||||
:param pad_val: 需要填充的值 | |||||
:param dtype: 数据的类型 | |||||
""" | """ | ||||
try: | try: | ||||
if not isinstance(batch_field[0], jittor.Var): | if not isinstance(batch_field[0], jittor.Var): | ||||
@@ -18,9 +18,9 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||||
""" | """ | ||||
用于检测数据的 dtype 类型, 根据内部和外部数据判断。 | 用于检测数据的 dtype 类型, 根据内部和外部数据判断。 | ||||
:param ele_dtype 内部数据的类型 | |||||
:param dtype 数据外部类型 | |||||
:param class_name 类的名称 | |||||
:param ele_dtype: 内部数据的类型 | |||||
:param dtype: 数据外部类型 | |||||
:param class_name: 类的名称 | |||||
""" | """ | ||||
if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype): | if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype): | ||||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | ||||
@@ -38,15 +38,15 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||||
class NumpyNumberPadder(Padder): | class NumpyNumberPadder(Padder): | ||||
""" | """ | ||||
可以将形如 [1, 2, 3] 这类的数据转为 np.array([1, 2, 3]) 。可以通过: | |||||
可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``np.array([1, 2, 3])`` 。可以通过:: | |||||
>>> NumpyNumberPadder.pad([1, 2, 3]) | >>> NumpyNumberPadder.pad([1, 2, 3]) | ||||
使用。 | 使用。 | ||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
:param pad_val: 该值无意义; | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`np.array` 类型; | |||||
:param dtype: 输出的数据的 dtype ; | |||||
""" | """ | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | ||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
@@ -54,21 +54,28 @@ class NumpyNumberPadder(Padder): | |||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val=0, dtype=None): | def pad(batch_field, pad_val=0, dtype=None): | ||||
""" | |||||
将 ``batch_field`` 数据 转为 :class:`numpy.array` 并 pad 到相同长度。 | |||||
:param batch_field: 输入的某个 field 的 batch 数据。 | |||||
:param pad_val: 需要填充的值 | |||||
:param dtype: 数据的类型 | |||||
""" | |||||
return np.array(batch_field, dtype=dtype) | return np.array(batch_field, dtype=dtype) | ||||
class NumpySequencePadder(Padder): | class NumpySequencePadder(Padder): | ||||
""" | """ | ||||
将类似于 [[1], [1, 2]] 的内容 pad 为 np.array([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||||
将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``np.array([[1, 0], [1, 2]])``, 可以 pad 多重嵌套的数据。 | |||||
可以通过以下的方式直接使用: | 可以通过以下的方式直接使用: | ||||
>>> NumpySequencePadder.pad([[1], [1, 2]], pad_val=-100, dtype=float) | >>> NumpySequencePadder.pad([[1], [1, 2]], pad_val=-100, dtype=float) | ||||
[[ 1. -100.] | [[ 1. -100.] | ||||
[ 1. 2.]] | [ 1. 2.]] | ||||
:param pad_val: pad 的值是多少。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
:param pad_val: pad 的值是多少; | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`np.array` 类型; | |||||
:param dtype: 输出的数据的 dtype ; | |||||
""" | """ | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | ||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
@@ -76,18 +83,25 @@ class NumpySequencePadder(Padder): | |||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val=0, dtype=None): | def pad(batch_field, pad_val=0, dtype=None): | ||||
""" | |||||
将 ``batch_field`` 数据 转为 :class:`numpy.array` 并 pad 到相同长度。 | |||||
:param batch_field: 输入的某个 field 的 batch 数据。 | |||||
:param pad_val: 需要填充的值 | |||||
:param dtype: 数据的类型 | |||||
""" | |||||
return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val) | return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val) | ||||
class NumpyTensorPadder(Padder): | class NumpyTensorPadder(Padder): | ||||
""" | """ | ||||
pad 类似于 [np.array([3, 4]), np.array([1])] 的 field 。若内部元素不为 np.ndarray ,则必须含有 tolist() 方法。 | |||||
pad 类似于 ``[np.array([3, 4]), np.array([1])]`` 的 field 。若内部元素不为 :class:`np.ndarray` ,则必须含有 :meth:`tolist` 方法。 | |||||
>>> NumpyTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100) | >>> NumpyTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100) | ||||
[[ 3. 4.] | [[ 3. 4.] | ||||
[ 1. -100.]] | [ 1. -100.]] | ||||
:param pad_val: pad 的值是多少。 | :param pad_val: pad 的值是多少。 | ||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`np.array` 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | :param dtype: 输出的数据的 dtype 是什么 | ||||
""" | """ | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | ||||
@@ -96,6 +110,13 @@ class NumpyTensorPadder(Padder): | |||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val=0, dtype=None): | def pad(batch_field, pad_val=0, dtype=None): | ||||
""" | |||||
将 ``batch_field`` 数据 转为 :class:`numpy.array` 并 pad 到相同长度。 | |||||
:param batch_field: 输入的某个 field 的 batch 数据。 | |||||
:param pad_val: 需要填充的值 | |||||
:param dtype: 数据的类型 | |||||
""" | |||||
try: | try: | ||||
if not isinstance(batch_field[0], np.ndarray): | if not isinstance(batch_field[0], np.ndarray): | ||||
batch_field = [np.array(field.tolist(), dtype=dtype) for field in batch_field] | batch_field = [np.array(field.tolist(), dtype=dtype) for field in batch_field] | ||||
@@ -0,0 +1,225 @@ | |||||
__all__ = [ | |||||
'OneflowNumberPadder', | |||||
'OneflowSequencePadder', | |||||
'OneflowTensorPadder' | |||||
] | |||||
from inspect import isclass | |||||
import numpy as np | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
numpy_to_oneflow_dtype_dict = { | |||||
np.bool_: oneflow.bool, | |||||
np.uint8: oneflow.uint8, | |||||
np.int8: oneflow.int8, | |||||
np.int32: oneflow.int32, | |||||
np.int64: oneflow.int64, | |||||
np.float16: oneflow.float16, | |||||
np.float32: oneflow.float32, | |||||
np.float64: oneflow.float32, # 这里都统一为到 float32 吧,这是由于 numpy 大部分时候都默认 float64 了 | |||||
} | |||||
number_to_oneflow_dtype_dict = { | |||||
float: oneflow.float32, # 因为 oneflow.tensor([1], dtype=float)是oneflow.float64 | |||||
int: oneflow.int64, | |||||
bool: oneflow.bool | |||||
} | |||||
from .padder import Padder | |||||
from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class | |||||
from .exceptions import * | |||||
def is_oneflow_tensor(dtype): | |||||
""" | |||||
判断是否为 oneflow 的 tensor | |||||
:param dtype 数据的 dtype 类型 | |||||
""" | |||||
if not isclass(dtype) and isinstance(dtype, oneflow.dtype): | |||||
return True | |||||
return False | |||||
def _get_dtype(ele_dtype, dtype, class_name): | |||||
""" | |||||
用于检测数据的 dtype 类型, 根据内部和外部数据判断。 | |||||
:param ele_dtype: 内部数据的类型 | |||||
:param dtype: 数据外部类型 | |||||
:param class_name: 类的名称 | |||||
""" | |||||
if not (ele_dtype is None or (is_number_or_numpy_number(ele_dtype) or is_oneflow_tensor(ele_dtype))): | |||||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | |||||
f"or numpy numbers or oneflow.Tensor but get `{ele_dtype}`.") | |||||
if dtype is not None: | |||||
if not (is_oneflow_tensor(dtype) or is_number(dtype)): | |||||
raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers " | |||||
f"or oneflow.dtype but get `{dtype}`.") | |||||
dtype = number_to_oneflow_dtype_dict.get(dtype, dtype) | |||||
else: | |||||
if ele_dtype is not None: | |||||
if (is_number(ele_dtype) or is_oneflow_tensor(ele_dtype)): | |||||
ele_dtype = number_to_oneflow_dtype_dict.get(ele_dtype, ele_dtype) | |||||
dtype = ele_dtype | |||||
elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了 | |||||
dtype = numpy_to_oneflow_dtype_dict.get(ele_dtype.type) | |||||
elif is_numpy_generic_class(ele_dtype): | |||||
dtype = numpy_to_oneflow_dtype_dict.get(ele_dtype) | |||||
return dtype | |||||
class OneflowNumberPadder(Padder): | |||||
""" | |||||
可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``oneflow.Tensor([1, 2, 3])``。 | |||||
:param pad_val: 该值无意义; | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`oneflow.Tensor` 类型; | |||||
:param dtype: 输出的数据的 dtype,。如 :class:`oneflow.long`, :class:`oneflow.float32`, :class:`int`, :class:`float` 等; | |||||
""" | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | |||||
@staticmethod | |||||
def pad(batch_field, pad_val=0, dtype=None): | |||||
""" | |||||
将 ``batch_field`` 数据 转为 :class:`oneflow.Tensor` 并 pad 到相同长度。 | |||||
:param batch_field: 输入的某个 field 的 batch 数据。 | |||||
:param pad_val: 需要填充的值 | |||||
:param dtype: 数据的类型 | |||||
""" | |||||
return oneflow.tensor(batch_field, dtype=dtype) | |||||
class OneflowSequencePadder(Padder): | |||||
""" | |||||
将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``oneflow.Tensor([[1, 0], [1, 2]])``, 可以 pad 多重嵌套的数据。 | |||||
:param pad_val: 需要 pad 的值; | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`oneflow.Tensor` 类型; | |||||
:param type: 输出的数据的 dtype,。如 :class:`oneflow.long`, :class:`oneflow.float32`, :class:`int`, :class:`float` 等; | |||||
""" | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | |||||
@staticmethod | |||||
def pad(batch_field, pad_val=0, dtype=None): | |||||
""" | |||||
将 ``batch_field`` 数据 转为 :class:`oneflow.Tensor` 并 pad 到相同长度。 | |||||
:param batch_field: 输入的某个 field 的 batch 数据。 | |||||
:param pad_val: 需要填充的值 | |||||
:param dtype: 数据的类型 | |||||
""" | |||||
tensor = get_padded_oneflow_tensor(batch_field, dtype=dtype, pad_val=pad_val) | |||||
return tensor | |||||
class OneflowTensorPadder(Padder): | |||||
""" | |||||
目前支持 ``[oneflow.tensor([3, 2], oneflow.tensor([1])]`` 类似的输入,若内部元素不为 :class:`oneflow.Tensor` ,则必须含有 :meth:`tolist` 方法。 | |||||
>>> OneflowTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100) | |||||
[[ 3. 4.] | |||||
[ 1. -100.]] | |||||
>>> OneflowTensorPadder.pad([oneflow.LongTensor([3, 4]), oneflow.LongTensor([1])], pad_val=-100) | |||||
tensor([[ 3, 4], | |||||
[ 1, -100]]) | |||||
:param pad_val: 需要 pad 的值。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`oneflow.Tensor` 类型。 | |||||
:param dtype: 输出的数据的 dtype,。如 :class:`oneflow.long`, :class:`oneflow.float32`, :class:`int`, :class:`float` 等; | |||||
""" | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | |||||
@staticmethod | |||||
def pad(batch_field, pad_val=0, dtype=None): | |||||
""" | |||||
将 ``batch_field`` 数据 转为 :class:`oneflow.Tensor` 并 pad 到相同长度。 | |||||
:param batch_field: 输入的某个 field 的 batch 数据。 | |||||
:param pad_val: 需要填充的值 | |||||
:param dtype: 数据的类型 | |||||
""" | |||||
device = None | |||||
try: | |||||
if not isinstance(batch_field[0], oneflow.Tensor): | |||||
batch_field = [oneflow.tensor(field.tolist(), dtype=dtype) for field in batch_field] | |||||
else: | |||||
batch_field = [field.to(dtype) for field in batch_field] | |||||
device = batch_field[0].device | |||||
if dtype is None: | |||||
dtype = batch_field[0].dtype | |||||
except AttributeError: | |||||
raise RuntimeError(f"If the field is not a oneflow.Tensor (it is {type(batch_field[0])}), " | |||||
f"it must have tolist() method.") | |||||
shapes = [field.shape for field in batch_field] | |||||
if len(batch_field) < 2: | |||||
max_shape = [len(batch_field)] + list(shapes[0]) | |||||
else: | |||||
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | |||||
tensor = oneflow.full(max_shape, value=pad_val, dtype=dtype, device=device) | |||||
for i, field in enumerate(batch_field): | |||||
slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | |||||
tensor[slices] = field | |||||
return tensor | |||||
def fill_tensor(batch_field, padded_batch, dtype): | |||||
""" | |||||
将 batch_field 中的值填入到 tensor 中。 | |||||
:param batch_field: 需要填充进入 array 中的内容 | |||||
:param padded_batch: 待填充的 tensor | |||||
:param dtype: 数据的类别 | |||||
:return: | |||||
""" | |||||
if padded_batch.ndim == 2: | |||||
for i, content_i in enumerate(batch_field): | |||||
padded_batch[i, :len(content_i)] = oneflow.tensor(content_i, dtype=dtype) | |||||
elif padded_batch.ndim == 3: | |||||
for i, content_i in enumerate(batch_field): | |||||
for j, content_ii in enumerate(content_i): | |||||
padded_batch[i, j, :len(content_ii)] = oneflow.tensor(content_ii, dtype=dtype) | |||||
elif padded_batch.ndim == 4: | |||||
try: # 应该是图像,所以直接应该就 ok 了。 | |||||
padded_batch = oneflow.tensor(batch_field) | |||||
except: | |||||
for i, content_i in enumerate(batch_field): | |||||
for j, content_ii in enumerate(content_i): | |||||
for k, content_iii in enumerate(content_ii): | |||||
padded_batch[i, j, k, :len(content_iii)] = oneflow.tensor(content_iii, dtype=dtype) | |||||
elif padded_batch.ndim == 1: | |||||
padded_batch[:] = oneflow.tensor(batch_field, dtype=dtype) | |||||
else: | |||||
raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please " | |||||
"report.") | |||||
return padded_batch | |||||
def get_padded_oneflow_tensor(batch_field, dtype=None, pad_val=0): | |||||
""" | |||||
例如: | |||||
[[1,2], [3]] -> oneflow.LongTensor([[1, 2], [3, 0]]) | |||||
:param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列) | |||||
/4d(多为图片)。 | |||||
:param dtype: 目标类别是什么 | |||||
:param pad_val: pad 的 value | |||||
:return: | |||||
""" | |||||
shapes = get_shape(batch_field) | |||||
tensor = oneflow.full(shapes, dtype=dtype, value=pad_val) | |||||
tensor = fill_tensor(batch_field, tensor, dtype=dtype) | |||||
return tensor |
@@ -1,7 +1,7 @@ | |||||
class Padder: | class Padder: | ||||
""" | """ | ||||
所有 Padder 对象父类,所有的 Padder 对象都会实现 pad(batch_field, pad_val=0, dtype=None) 的静态函数。 | |||||
所有 **Padder** 对象的父类,所有的 Padder 对象都会实现静态函数 ``pad(batch_field, pad_val=0, dtype=None)`` 。 | |||||
""" | """ | ||||
def __init__(self, pad_val, dtype): | def __init__(self, pad_val, dtype): | ||||
@@ -99,11 +99,11 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||||
class PaddleNumberPadder(Padder): | class PaddleNumberPadder(Padder): | ||||
""" | """ | ||||
可以将形如 [1, 2, 3] 这类的数据转为 paddle.Tensor([1, 2, 3]) | |||||
可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``paddle.Tensor([1, 2, 3])`` | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 | |||||
:param pad_val: 该值无意义; | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`paddle.Tensor` 类型; | |||||
:param dtype: 输出的数据的 dtype 是什么。如 :class:`int`, :class:`float`, :class:`int32` 等; | |||||
""" | """ | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | ||||
# 仅当 ele_dtype 是 python number/ numpy number 或者 tensor | # 仅当 ele_dtype 是 python number/ numpy number 或者 tensor | ||||
@@ -112,16 +112,23 @@ class PaddleNumberPadder(Padder): | |||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val=0, dtype=None): | def pad(batch_field, pad_val=0, dtype=None): | ||||
""" | |||||
将 ``batch_field`` 数据 转为 :class:`paddle.Tensor` 并 pad 到相同长度。 | |||||
:param batch_field: 输入的某个 field 的 batch 数据。 | |||||
:param pad_val: 需要填充的值 | |||||
:param dtype: 数据的类型 | |||||
""" | |||||
return paddle.to_tensor(batch_field, dtype=dtype) | return paddle.to_tensor(batch_field, dtype=dtype) | ||||
class PaddleSequencePadder(Padder): | class PaddleSequencePadder(Padder): | ||||
""" | """ | ||||
将类似于 [[1], [1, 2]] 的内容 pad 为 paddle.Tensor([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||||
将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``paddle.Tensor([[1, 0], [1, 2]])`` 可以 pad 多重嵌套的数据。 | |||||
:param pad_val: pad 的值。 | :param pad_val: pad 的值。 | ||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`paddle.Tensor` 类型; | |||||
:param dtype: 输出的数据的 dtype 是什么。如 :class:`int`, :class:`float`, :class:`int32` 等; | |||||
""" | """ | ||||
def __init__(self, ele_dtype=None, pad_val=0, dtype=None): | def __init__(self, ele_dtype=None, pad_val=0, dtype=None): | ||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | ||||
@@ -129,17 +136,30 @@ class PaddleSequencePadder(Padder): | |||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val=0, dtype=None): | def pad(batch_field, pad_val=0, dtype=None): | ||||
""" | |||||
将 ``batch_field`` 数据 转为 :class:`paddle.Tensor` 并 pad 到相同长度。 | |||||
:param batch_field: 输入的某个 field 的 batch 数据。 | |||||
:param pad_val: 需要填充的值 | |||||
:param dtype: 数据的类型 | |||||
""" | |||||
tensor = get_padded_paddle_tensor(batch_field, dtype=dtype, pad_val=pad_val) | tensor = get_padded_paddle_tensor(batch_field, dtype=dtype, pad_val=pad_val) | ||||
return tensor | return tensor | ||||
class PaddleTensorPadder(Padder): | class PaddleTensorPadder(Padder): | ||||
""" | """ | ||||
目前支持 [paddle.tensor([3, 2], paddle.tensor([2, 1])] 类似的,若内部元素不为 paddle.tensor ,则必须含有 tolist() 方法。 | |||||
目前支持 ``[paddle.tensor([3, 2], paddle.tensor([2, 1])]`` 类似的输入,若内部元素不为 :class:`paddle.Tensor` ,则必须含有 :meth:`tolist` 方法。 | |||||
>>> PaddleTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100) | |||||
[[ 3. 4.] | |||||
[ 1. -100.]] | |||||
>>> PaddleTensorPadder.pad([paddle.to_tensor([3, 4]), paddle.to_tensor([1])], pad_val=-100) | |||||
tensor([[ 3, 4], | |||||
[ 1, -100]]) | |||||
:param pad_val: pad 的值。 | :param pad_val: pad 的值。 | ||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`paddle.Tensor` 类型; | |||||
:param dtype: 输出的数据的 dtype 是什么。如 :class:`int`, :class:`float`, :class:`int32` 等; | |||||
""" | """ | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | ||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | ||||
@@ -147,6 +167,13 @@ class PaddleTensorPadder(Padder): | |||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val=0, dtype=None): | def pad(batch_field, pad_val=0, dtype=None): | ||||
""" | |||||
将 ``batch_field`` 数据 转为 :class:`paddle.Tensor` 并 pad 到相同长度。 | |||||
:param batch_field: 输入的某个 field 的 batch 数据。 | |||||
:param pad_val: 需要填充的值 | |||||
:param dtype: 数据的类型 | |||||
""" | |||||
try: | try: | ||||
if not isinstance(batch_field[0], paddle.Tensor): | if not isinstance(batch_field[0], paddle.Tensor): | ||||
batch_field = [np.array(field.tolist()) for field in batch_field] | batch_field = [np.array(field.tolist()) for field in batch_field] | ||||
@@ -13,9 +13,9 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||||
""" | """ | ||||
用于检测数据的 dtype 类型, 根据内部和外部数据判断。 | 用于检测数据的 dtype 类型, 根据内部和外部数据判断。 | ||||
:param ele_dtype 内部数据的类型 | |||||
:param dtype 数据外部类型 | |||||
:param class_name 类的名称 | |||||
:param ele_dtype: 内部数据的类型 | |||||
:param dtype: 数据外部类型 | |||||
:param class_name: 类的名称 | |||||
""" | """ | ||||
if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype): | if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype): | ||||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | ||||
@@ -34,11 +34,11 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||||
class RawNumberPadder(Padder): | class RawNumberPadder(Padder): | ||||
""" | """ | ||||
可以将形如 [1, 2, 3] 这类的数据转为 [1, 2, 3] 。实际上该 padder 无意义。 | |||||
可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``[1, 2, 3]`` 。实际上该 padder 无意义。 | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
:param pad_val: | |||||
:param ele_dtype: | |||||
:param dtype: | |||||
""" | """ | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | ||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
@@ -54,11 +54,11 @@ class RawNumberPadder(Padder): | |||||
class RawSequencePadder(Padder): | class RawSequencePadder(Padder): | ||||
""" | """ | ||||
将类似于 [[1], [1, 2]] 的内容 pad 为 [[1, 0], [1, 2]] 。可以 pad 多重嵌套的数据。 | |||||
将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``[[1, 0], [1, 2]]`` 。可以 pad 多重嵌套的数据。 | |||||
:param pad_val: pad 的值 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
:param pad_val: pad 的值; | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`np.array` 类型; | |||||
:param dtype: 输出的数据的 dtype ; | |||||
""" | """ | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | ||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
@@ -68,8 +68,8 @@ class RawSequencePadder(Padder): | |||||
def pad(batch_field, pad_val=0, dtype=None): | def pad(batch_field, pad_val=0, dtype=None): | ||||
""" | """ | ||||
:param batch_field: | |||||
:param pad_val: | |||||
:param batch_field: 输入的某个 field 的 batch 数据。 | |||||
:param pad_val: 需要填充的值 | |||||
:param dtype: 该参数无意义。 | :param dtype: 该参数无意义。 | ||||
:return: | :return: | ||||
""" | """ | ||||
@@ -78,11 +78,11 @@ class RawSequencePadder(Padder): | |||||
class RawTensorPadder(Padder): | class RawTensorPadder(Padder): | ||||
""" | """ | ||||
将类似于 [[1], [1, 2]] 的内容 pad 为 [[1, 0], [1, 2]] 。可以 pad 多重嵌套的数据。 | |||||
将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``[[1, 0], [1, 2]]`` 。可以 pad 多重嵌套的数据。 | |||||
:param pad_val: pad 的值 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
:param pad_val: pad 的值; | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`np.array` 类型; | |||||
:param dtype: 输出的数据的 dtype ; | |||||
""" | """ | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | ||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
@@ -92,8 +92,8 @@ class RawTensorPadder(Padder): | |||||
def pad(batch_field, pad_val=0, dtype=None): | def pad(batch_field, pad_val=0, dtype=None): | ||||
""" | """ | ||||
:param batch_field: | |||||
:param pad_val: | |||||
:param batch_field: 输入的某个 field 的 batch 数据。 | |||||
:param pad_val: 需要填充的值 | |||||
:param dtype: 该参数无意义。 | :param dtype: 该参数无意义。 | ||||
:return: | :return: | ||||
""" | """ | ||||
@@ -38,7 +38,7 @@ def is_torch_tensor(dtype): | |||||
""" | """ | ||||
判断是否为 torch 的 tensor | 判断是否为 torch 的 tensor | ||||
:param dtype 数据的 dtype 类型 | |||||
:param dtype: 数据的 dtype 类型 | |||||
""" | """ | ||||
if not isclass(dtype) and isinstance(dtype, torch.dtype): | if not isclass(dtype) and isinstance(dtype, torch.dtype): | ||||
return True | return True | ||||
@@ -49,9 +49,9 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||||
""" | """ | ||||
用于检测数据的 dtype 类型, 根据内部和外部数据判断。 | 用于检测数据的 dtype 类型, 根据内部和外部数据判断。 | ||||
:param ele_dtype 内部数据的类型 | |||||
:param dtype 数据外部类型 | |||||
:param class_name 类的名称 | |||||
:param ele_dtype: 内部数据的类型 | |||||
:param dtype: 数据外部类型 | |||||
:param class_name: 类的名称 | |||||
""" | """ | ||||
if not (ele_dtype is None or (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))): | if not (ele_dtype is None or (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))): | ||||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | ||||
@@ -77,11 +77,11 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||||
class TorchNumberPadder(Padder): | class TorchNumberPadder(Padder): | ||||
""" | """ | ||||
可以将形如 [1, 2, 3] 这类的数据转为 torch.Tensor([1, 2, 3]) | |||||
可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``torch.Tensor([1, 2, 3])`` | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 | |||||
:param pad_val: 该值无意义; | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`torch.Tensor` 类型; | |||||
:param dtype: 输出的数据的 dtype 是什么。如 :class:`torch.long`, :class:`torch.float32`, :class:`int`, :class:`float` 等; | |||||
""" | """ | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | ||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | ||||
@@ -94,11 +94,11 @@ class TorchNumberPadder(Padder): | |||||
class TorchSequencePadder(Padder): | class TorchSequencePadder(Padder): | ||||
""" | """ | ||||
将类似于 [[1], [1, 2]] 的内容 pad 为 torch.Tensor([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||||
将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``torch.Tensor([[1, 0], [1, 2]])`` 可以 pad 多重嵌套的数据。 | |||||
:param pad_val: 需要 pad 的值。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 | |||||
:param pad_val: 需要 pad 的值; | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`torch.Tensor` 类型; | |||||
:param dtype: 输出的数据的 dtype 是什么。如 :class:`torch.long`, :class:`torch.float32`, :class:`int`, :class:`float` 等; | |||||
""" | """ | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | ||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | ||||
@@ -112,7 +112,7 @@ class TorchSequencePadder(Padder): | |||||
class TorchTensorPadder(Padder): | class TorchTensorPadder(Padder): | ||||
""" | """ | ||||
目前支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的。若内部元素不为 torch.tensor ,则必须含有 tolist() 方法。 | |||||
目前支持 ``[torch.tensor([3, 2], torch.tensor([1])]`` 类似的输入。若内部元素不为 :class:`torch.Tensor` ,则必须含有 :meth:`tolist` 方法。 | |||||
>>> TorchTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100) | >>> TorchTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100) | ||||
[[ 3. 4.] | [[ 3. 4.] | ||||
@@ -121,9 +121,9 @@ class TorchTensorPadder(Padder): | |||||
tensor([[ 3, 4], | tensor([[ 3, 4], | ||||
[ 1, -100]]) | [ 1, -100]]) | ||||
:param pad_val: 需要 pad 的值。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 | |||||
:param pad_val: 需要 pad 的值; | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`torch.Tensor` 类型; | |||||
:param dtype: 输出的数据的 dtype 是什么。如 :class:`torch.long`, :class:`torch.float32`, :class:`int`, :class:`float` 等; | |||||
""" | """ | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | ||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | ||||
@@ -5,6 +5,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
__all__ = [] | |||||
def is_torch_tensor_dtype(dtype) -> bool: | def is_torch_tensor_dtype(dtype) -> bool: | ||||
""" | """ | ||||
@@ -78,13 +78,12 @@ def fill_array(batch_field:List, padded_batch:np.ndarray): | |||||
def get_padded_numpy_array(batch_field: List, dtype=None, pad_val=0) -> np.ndarray: | def get_padded_numpy_array(batch_field: List, dtype=None, pad_val=0) -> np.ndarray: | ||||
""" | """ | ||||
例如: | |||||
[[1,2], [3]] -> np.array([[1, 2], [3, 0]]) | |||||
将输入 pad 为 :class:`numpy.arraay` 类型,如:``[[1,2], [3]] -> np.array([[1, 2], [3, 0]])`` | |||||
:param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列) | |||||
/4d(多为图片)。 | |||||
:param dtype: 目标类别是什么 | |||||
:param pad_val: pad 的 value | |||||
:param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 **1d** (多为句子长度)/ **2d** (多为文本序列)/ **3d** (多为字符序列) | |||||
/4d(多为图片); | |||||
:param dtype: 输出数据的 dtype 类型; | |||||
:param pad_val: 填充值; | |||||
:return: | :return: | ||||
""" | """ | ||||
shapes = get_shape(batch_field) | shapes = get_shape(batch_field) | ||||
@@ -1,5 +1,5 @@ | |||||
r""" | r""" | ||||
``Evaluator`` 是新版 fastNLP 中用来进行评测模型的评测器,其与 ``Trainer`` 相对应,二者共同构建起了 fastNLP 中**训练**和**评测**的框架。 | |||||
``Evaluator`` 是新版 **fastNLP** 中用来进行评测模型的评测器,其与 ``Trainer`` 相对应,二者共同构建起了 **fastNLP** 中 **训练** 和 **评测** 的框架。 | |||||
``Evaluator`` 的整体架构与 ``Trainer`` 类似,也是利用 ``Driver`` 来负责底层的评测逻辑。通过使用 ``Evaluator``,您可以快速、方便、准确地 | ``Evaluator`` 的整体架构与 ``Trainer`` 类似,也是利用 ``Driver`` 来负责底层的评测逻辑。通过使用 ``Evaluator``,您可以快速、方便、准确地 | ||||
对您的模型进行全方位地评测。 | 对您的模型进行全方位地评测。 | ||||
@@ -75,11 +75,11 @@ class Evaluator: | |||||
:param device: 等价于 ``Trainer`` 中的 ``device`` 参数; | :param device: 等价于 ``Trainer`` 中的 ``device`` 参数; | ||||
:param evaluate_batch_step_fn: 您可以传入该参数来定制每次评测一个 batch 的数据时所执行的函数。该函数应接受的两个参数为 ``evaluator`` 和 ``batch``, | :param evaluate_batch_step_fn: 您可以传入该参数来定制每次评测一个 batch 的数据时所执行的函数。该函数应接受的两个参数为 ``evaluator`` 和 ``batch``, | ||||
不需要有返回值;可以参考 :meth:`~fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop.batch_step_fn`; | 不需要有返回值;可以参考 :meth:`~fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop.batch_step_fn`; | ||||
:param evaluate_fn: 用来控制 ``Evaluator`` 在评测的前向传播过程中调用的是哪一个函数,例如对于 pytorch 而言,通过该参数确定使用的是 ``model.evaluate_step`` 还是 | |||||
``model.forward``(不同训练框架所使用的的前向传播函数的方法名称不同); | |||||
:param evaluate_fn: 用来控制 ``Evaluator`` 在评测的前向传播过程中调用的是哪一个函数,例如对于 pytorch 而言,通过该参数确定使用的是 :meth:`model.evaluate_step` 还是 | |||||
:meth:`model.forward` (不同训练框架所使用的的前向传播函数的方法名称不同); | |||||
1. 如果该值是 ``None``,那么我们会默认使用 ``evaluate_step`` 当做前向传播的函数,如果在模型中没有找到该方法,则使用训练框架默认的前向传播函数; | 1. 如果该值是 ``None``,那么我们会默认使用 ``evaluate_step`` 当做前向传播的函数,如果在模型中没有找到该方法,则使用训练框架默认的前向传播函数; | ||||
2. 如果为 ``str`` 类型,例如为 ``my_evaluate_step_fn``,则尝试寻找 ``model.my_evaluate_step_fn``,如果找不到则直接报错; | |||||
2. 如果为 ``str`` 类型,例如为 ``'my_evaluate_step_fn'``,则尝试寻找 :meth:`model.my_evaluate_step_fn`,如果找不到则直接报错; | |||||
:param input_mapping: 等价于 ``Trainer`` 中的 ``input_mapping`` 参数;对具体的用于评测一个 batch 的数据使用 ``input_mapping`` 处理之后再输入到 ``model`` 以及 ``metric`` 中。如果针对 | :param input_mapping: 等价于 ``Trainer`` 中的 ``input_mapping`` 参数;对具体的用于评测一个 batch 的数据使用 ``input_mapping`` 处理之后再输入到 ``model`` 以及 ``metric`` 中。如果针对 | ||||
``model`` 和 ``metric`` 需要不同的 ``mapping``,请考虑使用 ``evaluate_batch_step_fn`` 参数定制; | ``model`` 和 ``metric`` 需要不同的 ``mapping``,请考虑使用 ``evaluate_batch_step_fn`` 参数定制; | ||||
@@ -97,18 +97,28 @@ class Evaluator: | |||||
``metric`` 的计算都是自动化的,因此其一定需要参数匹配:根据 ``metric.update`` 的函数签名直接从字典数据中抽取其需要的参数传入进去; | ``metric`` 的计算都是自动化的,因此其一定需要参数匹配:根据 ``metric.update`` 的函数签名直接从字典数据中抽取其需要的参数传入进去; | ||||
:param fp16: 是否在评测时使用 fp16; | |||||
:param fp16: 是否在评测时使用 fp16 混合精度; | |||||
:param verbose: 是否打印 evaluate 的结果; | :param verbose: 是否打印 evaluate 的结果; | ||||
:kwargs: | :kwargs: | ||||
* *torch_kwargs* -- 等价于 ``Trainer`` 中的 ``torch_kwargs`` 参数; | * *torch_kwargs* -- 等价于 ``Trainer`` 中的 ``torch_kwargs`` 参数; | ||||
* *paddle_kwargs* -- 等价于 ``Trainer`` 中的 ``paddle_kwargs`` 参数; | |||||
* *fairscale_kwargs* -- 等价于 ``Trainer`` 中的 ``fairscale_kwargs`` 参数; | |||||
* *deepspeed_kwargs* -- 等价于 ``Trainer`` 中的 ``deepspeed_kwargs`` 参数; | |||||
* *oneflow_kwargs* -- 等价于 ``Trainer`` 中的 ``oneflow_kwargs`` 参数; | |||||
* *data_device* -- 等价于 ``Trainer`` 中的 ``data_device`` 参数; | * *data_device* -- 等价于 ``Trainer`` 中的 ``data_device`` 参数; | ||||
* *model_use_eval_mode* (``bool``) -- | * *model_use_eval_mode* (``bool``) -- | ||||
是否在评测的时候将 ``model`` 的状态设置成 ``eval`` 状态。在 ``eval`` 状态下,``model`` 的 | |||||
``dropout`` 与 ``batch normalization`` 将会关闭。默认为 ``True``。如果为 ``False``,``fastNLP`` 不会对 ``model`` 的 ``evaluate`` 状态做任何设置。无论 | |||||
该值是什么,``fastNLP`` 都会在评测后将 ``model`` 的状态设置为 ``train``; | |||||
是否在评测的时候将 ``model`` 的状态设置成 ``eval`` 状态。在 ``eval`` 状态下,``model`` 的 | |||||
``dropout`` 与 ``batch normalization`` 将会关闭。默认为 ``True``。如果为 ``False``,``fastNLP`` 不会对 ``model`` 的 ``evaluate`` 状态做任何设置。无论 | |||||
该值是什么,``fastNLP`` 都会在评测后将 ``model`` 的状态设置为 ``train``; | |||||
* *use_dist_sampler* -- | * *use_dist_sampler* -- | ||||
是否使用分布式评测的方式。仅当 ``driver`` 为分布式类型时,该参数才有效。默认为根据 ``driver`` 是否支持 | |||||
分布式进行设置。如果为 ``True``,将使得每个进程上的 ``dataloader`` 自动使用不同数据,所有进程的数据并集是整个数据集; | |||||
表示在 ``Evaluator`` 中在使用分布式的时候是否将保证 dataloader 的 ``sampler`` 替换为 | |||||
分布式的 ``sampler``,其特点是每个卡上的数据之间不重叠,所有卡上数据的加起来是整个数据集。若传入的 dataloader | |||||
的 sampler 为: | |||||
- 深度学习框架自带的默认 sampler ; | |||||
- fastNLP 的 Sampler ; | |||||
则将替换为 :class:`~fastNLP.UnrepeatedSequentialSampler`,如果这个行为不是期待的,请本参数设置为 ``False``,并针对每个卡控制其可以 | |||||
用到的数据。如果不是以上两类 sampler ,fastNLP 将报错。 | |||||
* *output_from_new_proc* -- 等价于 ``Trainer`` 中的 ``output_from_new_proc`` 参数; | * *output_from_new_proc* -- 等价于 ``Trainer`` 中的 ``output_from_new_proc`` 参数; | ||||
* *progress_bar* -- 等价于 ``Trainer`` 中的 ``progress_bar`` 参数; | * *progress_bar* -- 等价于 ``Trainer`` 中的 ``progress_bar`` 参数; | ||||
* *check_dataloader_legality* -- 是否检查 ``DataLoader`` 是否合法,默认为 ``True`` 。 | * *check_dataloader_legality* -- 是否检查 ``DataLoader`` 是否合法,默认为 ``True`` 。 | ||||
@@ -119,8 +129,8 @@ class Evaluator: | |||||
_evaluate_batch_loop: Loop | _evaluate_batch_loop: Loop | ||||
def __init__(self, model, dataloaders, metrics: Optional[Dict] = None, | def __init__(self, model, dataloaders, metrics: Optional[Dict] = None, | ||||
driver: Union[str, Driver] = 'torch', device: Optional[Union[int, List[int], str]] = None, | |||||
evaluate_batch_step_fn: Optional[callable] = None, evaluate_fn: Optional[str] = None, | |||||
driver: Union[str, Driver] = 'auto', device: Optional[Union[int, List[int], str]] = None, | |||||
evaluate_batch_step_fn: Optional[Callable] = None, evaluate_fn: Optional[str] = None, | |||||
input_mapping: Optional[Union[Callable, Dict]] = None, | input_mapping: Optional[Union[Callable, Dict]] = None, | ||||
output_mapping: Optional[Union[Callable, Dict]] = None, model_wo_auto_param_call: bool = False, | output_mapping: Optional[Union[Callable, Dict]] = None, model_wo_auto_param_call: bool = False, | ||||
fp16: bool = False, verbose: int = 1, **kwargs): | fp16: bool = False, verbose: int = 1, **kwargs): | ||||
@@ -200,16 +210,16 @@ class Evaluator: | |||||
""" | """ | ||||
用于帮助您加载模型的辅助函数; | 用于帮助您加载模型的辅助函数; | ||||
:param folder: 存放着您需要加载的 model 的文件夹,默认会尝试读取该文件夹下的 fastnlp_model.pkl.tar 文件。在 model_load_fn 不为空时, | |||||
直接将该 folder 传递到 model_load_fn 中; | |||||
:param only_state_dict: 要读取的文件中是否仅包含模型权重。在 ``model_load_fn 不为 None`` 时,该参数无意义; | |||||
:param model_load_fn: ``callable`` 的函数,接受一个 folder 作为参数,需要注意该函数不需要返回任何内容; | |||||
:param folder: 存放着您需要加载的 model 的文件夹,默认会尝试读取该文件夹下的 ``fastnlp_model.pkl.tar`` 文件。在 ``model_load_fn`` 不为空时, | |||||
直接将该 folder 传递到 ``model_load_fn`` 中; | |||||
:param only_state_dict: 要读取的文件中是否仅包含模型权重。在 ``model_load_fn`` 不为 ``None`` 时,该参数无意义; | |||||
:param model_load_fn: :class:`Callable` 的函数,接受一个 folder 作为参数,需要注意该函数不需要返回任何内容; | |||||
:param kwargs: 理论上您不需要使用到该参数; | :param kwargs: 理论上您不需要使用到该参数; | ||||
.. note:: | .. note:: | ||||
注意您需要在初始化 ``Evaluator`` 后再通过 ``evaluator`` 实例来调用该函数;这意味着您需要保证在保存和加载时使用的 ``driver`` 是属于同一个 | 注意您需要在初始化 ``Evaluator`` 后再通过 ``evaluator`` 实例来调用该函数;这意味着您需要保证在保存和加载时使用的 ``driver`` 是属于同一个 | ||||
训练框架的,例如都是 ``pytorch`` 或者 ``paddle``; | |||||
训练框架的,例如都是 **pytorch** 或者 **PaddlePaddle** ; | |||||
""" | """ | ||||
self.driver.barrier() | self.driver.barrier() | ||||
if not isinstance(folder, (io.BytesIO, BinaryIO)): | if not isinstance(folder, (io.BytesIO, BinaryIO)): | ||||
@@ -237,15 +247,14 @@ class Evaluator: | |||||
""" | """ | ||||
该函数是在 ``Evaluator`` 初始化后用于真正开始评测的函数; | 该函数是在 ``Evaluator`` 初始化后用于真正开始评测的函数; | ||||
返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 | |||||
返回一个字典类型的数据,其中 key 为 metric 的名字,value 为对应 metric 的结果。 | |||||
1. 如果存在多个metric,一个dataloader的情况,key的命名规则是 | |||||
``metric_indicator_name#metric_name`` | |||||
1. 如果存在多个 metric ,一个 dataloader 的情况,key 的命名规则是 | |||||
``metric_indicator_name#metric_name``; | |||||
2. 如果存在多个数据集,一个metric的情况,key的命名规则是 | 2. 如果存在多个数据集,一个metric的情况,key的命名规则是 | ||||
``metric_indicator_name#metric_name#dataloader_name`` (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。 | |||||
如果存在多个metric,多个dataloader的情况,key的命名规则是 | |||||
``metric_indicator_name#metric_name#dataloader_name`` | |||||
其中 metric_indicator_name 可能不存在; | |||||
``metric_indicator_name#metric_name#dataloader_name`` (其中 **#** 是默认的 separator ,可以通过 Evaluator 初始化参数修改); | |||||
3. 如果存在多个metric,多个dataloader的情况,key的命名规则是 | |||||
``metric_indicator_name#metric_name#dataloader_name``,其中 metric_indicator_name 可能不存在; | |||||
:param num_eval_batch_per_dl: 每个 dataloader 测试前多少个 batch 的数据,-1 为测试所有数据。 | :param num_eval_batch_per_dl: 每个 dataloader 测试前多少个 batch 的数据,-1 为测试所有数据。 | ||||
:return: 返回评测得到的结果,是一个没有嵌套的字典; | :return: 返回评测得到的结果,是一个没有嵌套的字典; | ||||
@@ -276,8 +285,9 @@ class Evaluator: | |||||
raise e | raise e | ||||
finally: | finally: | ||||
self.finally_progress_bar() | self.finally_progress_bar() | ||||
metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False) | |||||
if len(metric_results) > 0: # 如果 metric 不为 None 需要 print 。 | if len(metric_results) > 0: # 如果 metric 不为 None 需要 print 。 | ||||
metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False) | |||||
# metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False) | |||||
if self.verbose: | if self.verbose: | ||||
if self.progress_bar == 'rich': | if self.progress_bar == 'rich': | ||||
f_rich_progress.print(metric_results) | f_rich_progress.print(metric_results) | ||||
@@ -356,7 +366,7 @@ class Evaluator: | |||||
def reset(self): | def reset(self): | ||||
""" | """ | ||||
调用所有 metric 的 reset() 方法,清除累积的状态。 | |||||
调用所有 metric 的 :meth:`reset` 方法,清除累积的状态。 | |||||
:return: | :return: | ||||
""" | """ | ||||
@@ -364,7 +374,7 @@ class Evaluator: | |||||
def update(self, batch, outputs): | def update(self, batch, outputs): | ||||
""" | """ | ||||
自动调用所有 metric 的 update 方法,会根据不同 metric 的参数列表进行匹配传参。 | |||||
自动调用所有 metric 的 :meth:`update` 方法,会根据不同 metric 的参数列表进行匹配传参。 | |||||
:param batch: 一般是来自于 DataLoader 的输出,如果不为 dict 类型的话,该值将被忽略。 | :param batch: 一般是来自于 DataLoader 的输出,如果不为 dict 类型的话,该值将被忽略。 | ||||
:param outputs: 一般是来自于模型的输出。类别应为 dict 或者 dataclass 类型。 | :param outputs: 一般是来自于模型的输出。类别应为 dict 或者 dataclass 类型。 | ||||
@@ -374,7 +384,7 @@ class Evaluator: | |||||
def get_metric(self) -> Dict: | def get_metric(self) -> Dict: | ||||
""" | """ | ||||
调用所有 metric 的 get_metric 方法,并返回结果。其中 key 为 metric 的名称,value 是各个 metric 的结果。 | |||||
调用所有 metric 的 :meth:`get_metric` 方法,并返回结果。其中 key 为 metric 的名称,value 是各个 metric 的结果。 | |||||
:return: | :return: | ||||
""" | """ | ||||
@@ -383,11 +393,9 @@ class Evaluator: | |||||
@property | @property | ||||
def metrics_wrapper(self): | def metrics_wrapper(self): | ||||
""" | """ | ||||
由于需要保持 Evaluator 中 metrics 对象与用户传入的 metrics 保持完全一致(方便他在 evaluate_batch_step_fn )中使用,同时也为了支持 | |||||
由于需要保持 Evaluator 中 ``metrics`` 对象与用户传入的 ``metrics`` 保持完全一致(方便在 ``evaluate_batch_step_fn`` )中使用,同时也为了支持 | |||||
不同形式的 metric( fastNLP 的 metric/torchmetrics 等),所以 Evaluator 在进行 metric 操作的时候都调用 metrics_wrapper | 不同形式的 metric( fastNLP 的 metric/torchmetrics 等),所以 Evaluator 在进行 metric 操作的时候都调用 metrics_wrapper | ||||
进行操作。 | 进行操作。 | ||||
Returns: | |||||
""" | """ | ||||
if self._metric_wrapper is None: | if self._metric_wrapper is None: | ||||
self._metric_wrapper = _MetricsWrapper(self.metrics, evaluator=self) | self._metric_wrapper = _MetricsWrapper(self.metrics, evaluator=self) | ||||
@@ -395,11 +403,12 @@ class Evaluator: | |||||
def evaluate_step(self, batch): | def evaluate_step(self, batch): | ||||
""" | """ | ||||
将 batch 传递到model中进行处理,根据当前 evaluate_fn 选择进行 evaluate 。会将返回结果经过 output_mapping 处理后再 | |||||
返回。 | |||||
将 ``batch`` 传递到 model 中进行处理,根据当前 ``evaluate_fn`` 选择进行 evaluate 。会将返回结果经过 ``output_mapping`` | |||||
处理后再 | |||||
返回。 | |||||
:param batch: {evaluate_fn} 函数支持的输入类型 | |||||
:return: {evaluate_fn} 函数的输出结果,如果有设置 output_mapping ,将是 output_mapping 之后的结果。 | |||||
:param batch: ``evaluate_fn`` 函数支持的输入类型 | |||||
:return: ``evaluate_fn`` 函数的输出结果,如果有设置 ``output_mapping`` ,将是 ``output_mapping`` 之后的结果。 | |||||
""" | """ | ||||
outputs = self.driver.model_call(batch, self._evaluate_step, self._evaluate_step_signature_fn) | outputs = self.driver.model_call(batch, self._evaluate_step, self._evaluate_step_signature_fn) | ||||
outputs = match_and_substitute_params(self.output_mapping, outputs) | outputs = match_and_substitute_params(self.output_mapping, outputs) | ||||
@@ -408,7 +417,7 @@ class Evaluator: | |||||
@property | @property | ||||
def metrics(self): | def metrics(self): | ||||
""" | """ | ||||
返回用户传入的 metrics 对象。 | |||||
返回用户传入的 ``metrics`` 对象。 | |||||
:return: | :return: | ||||
""" | """ | ||||
@@ -13,7 +13,7 @@ class EvaluateBatchLoop(Loop): | |||||
r""" | r""" | ||||
``EvaluateBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的评测迭代过程; | ``EvaluateBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的评测迭代过程; | ||||
:param batch_step_fn: 您可以传入该参数来替换默认的 bath_step_fn; | |||||
:param batch_step_fn: 您可以传入该参数来替换默认的 ``bath_step_fn``; | |||||
""" | """ | ||||
def __init__(self, batch_step_fn:Optional[Callable]=None): | def __init__(self, batch_step_fn:Optional[Callable]=None): | ||||
if batch_step_fn is not None: | if batch_step_fn is not None: | ||||
@@ -21,10 +21,10 @@ class EvaluateBatchLoop(Loop): | |||||
def run(self, evaluator, dataloader) -> Dict: | def run(self, evaluator, dataloader) -> Dict: | ||||
r""" | r""" | ||||
需要返回在传入的 dataloader 中的 evaluation 结果 | |||||
需要返回在传入的 ``dataloader`` 中的 evaluation 结果 | |||||
:param evaluator: Evaluator 对象 | |||||
:param dataloader: 当前需要进行评测的dataloader | |||||
:param evaluator: :class:`~fastNLP.core.controllers.Evaluator` 对象 | |||||
:param dataloader: 当前需要进行评测的 ``dataloader`` | |||||
:return: | :return: | ||||
""" | """ | ||||
iterator = iter(dataloader) | iterator = iter(dataloader) | ||||
@@ -58,10 +58,10 @@ class EvaluateBatchLoop(Loop): | |||||
@staticmethod | @staticmethod | ||||
def batch_step_fn(evaluator, batch): | def batch_step_fn(evaluator, batch): | ||||
r""" | r""" | ||||
针对一个 batch 的数据的评测过程; | |||||
针对一个 ``batch`` 的数据的评测过程; | |||||
:param evaluator: Evaluator 对象 | |||||
:param batch: 当前需要评测的一个 batch 的数据; | |||||
:param evaluator: :class:`~fastNLP.core.controllers.Evaluator` 对象 | |||||
:param batch: 当前需要评测的一个 ``batch`` 的数据; | |||||
""" | """ | ||||
outputs = evaluator.evaluate_step(batch) # 将batch输入到model中得到结果 | outputs = evaluator.evaluate_step(batch) # 将batch输入到model中得到结果 | ||||
evaluator.update(batch, outputs) # evaluator将根据metric的形参名字从batch/outputs中取出对应的值进行赋值 | evaluator.update(batch, outputs) # evaluator将根据metric的形参名字从batch/outputs中取出对应的值进行赋值 |
@@ -1,5 +1,5 @@ | |||||
r""" | r""" | ||||
``TrainBatchLoop`` 和 ``EvaluateBatchLoop`` 的父类,为了在实现 fastNLP 主要功能的同时保证 fastNLP 的易用性和代码的易读性,我们只对 | |||||
``TrainBatchLoop`` 和 ``EvaluateBatchLoop`` 的父类,为了在实现 **fastNLP** 主要功能的同时保证 **fastNLP** 的易用性和代码的易读性,我们只对 | |||||
训练中的循环做了非常简单的抽象,``Loop`` 表示的是在训练或者评测的过程中针对单独一个 ``dataloader`` 的一个 ``epoch`` 的运算过程; | 训练中的循环做了非常简单的抽象,``Loop`` 表示的是在训练或者评测的过程中针对单独一个 ``dataloader`` 的一个 ``epoch`` 的运算过程; | ||||
更为具体的使用详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop` 和 | 更为具体的使用详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop` 和 | ||||
@@ -24,7 +24,7 @@ class Loop: | |||||
.. note:: | .. note:: | ||||
``Trainer`` 和 ``Evaluator`` 中都提供了方便您进行定制 ``Loop`` 的接口函数,例如 ``Trainer.train_step``,``Trainer.backward``等; | |||||
``Trainer`` 和 ``Evaluator`` 中都提供了方便您进行定制 ``Loop`` 的接口函数,例如 ``Trainer.train_step``, ``Trainer.backward`` 等; | |||||
在定制您自己的 ``TrainBatchLoop`` 时,请务必记得在正确的时机调用对应的 callback 函数,详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop` | 在定制您自己的 ``TrainBatchLoop`` 时,请务必记得在正确的时机调用对应的 callback 函数,详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop` | ||||
中对于 callback 函数的调用; | 中对于 callback 函数的调用; | ||||
@@ -34,5 +34,5 @@ class Loop: | |||||
@staticmethod | @staticmethod | ||||
def batch_step_fn(controller: Union["Trainer", "Evaluator"], batch): | def batch_step_fn(controller: Union["Trainer", "Evaluator"], batch): | ||||
r""" | r""" | ||||
对于具体的一个 batch 的数据,实现训练或者评测过程中的一步; | |||||
对于具体的一个 ``batch`` 的数据,实现训练或者评测过程中的一步; | |||||
""" | """ |
@@ -14,7 +14,7 @@ class TrainBatchLoop(Loop): | |||||
r""" | r""" | ||||
``TrainBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的训练迭代过程; | ``TrainBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的训练迭代过程; | ||||
:param batch_step_fn: 您可以传入该参数来替换默认的 bath_step_fn; | |||||
:param batch_step_fn: 您可以传入该参数来替换默认的 ``bath_step_fn``; | |||||
""" | """ | ||||
def __init__(self, batch_step_fn: Optional[Callable] = None): | def __init__(self, batch_step_fn: Optional[Callable] = None): | ||||
@@ -23,14 +23,14 @@ class TrainBatchLoop(Loop): | |||||
def run(self, trainer, dataloader): | def run(self, trainer, dataloader): | ||||
r""" | r""" | ||||
对传入的 dataloader 进行一个 epoch 的主要的训练的循环过程; | |||||
对传入的 ``dataloader`` 进行一个 epoch 的主要的训练的循环过程; | |||||
.. note:: | .. note:: | ||||
您不需要自己主动地调用该方法,``Trainer`` 会负责调用该方法来完成训练过程; | 您不需要自己主动地调用该方法,``Trainer`` 会负责调用该方法来完成训练过程; | ||||
:param trainer: ``Trainer`` 实例; | |||||
:param dataloader: 当前训练所使用的 dataloader; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
:param dataloader: 当前训练所使用的 ``dataloader``; | |||||
""" | """ | ||||
get_batch_indices = dataloader.get_batch_indices if callable(getattr(dataloader, 'get_batch_indices', None))\ | get_batch_indices = dataloader.get_batch_indices if callable(getattr(dataloader, 'get_batch_indices', None))\ | ||||
else lambda *args, **kwargs: None | else lambda *args, **kwargs: None | ||||
@@ -41,10 +41,12 @@ class TrainBatchLoop(Loop): | |||||
batch = next(dataloader) | batch = next(dataloader) | ||||
indices = get_batch_indices() | indices = get_batch_indices() | ||||
except StopIteration: | except StopIteration: | ||||
trainer.on_fetch_data_end() | |||||
break | break | ||||
trainer.on_fetch_data_end() | |||||
try: | try: | ||||
trainer.on_fetch_data_end() | |||||
batch = match_and_substitute_params(trainer.input_mapping, batch) | batch = match_and_substitute_params(trainer.input_mapping, batch) | ||||
batch = trainer.move_data_to_device(batch) | batch = trainer.move_data_to_device(batch) | ||||
@@ -66,10 +68,10 @@ class TrainBatchLoop(Loop): | |||||
@staticmethod | @staticmethod | ||||
def batch_step_fn(trainer, batch): | def batch_step_fn(trainer, batch): | ||||
r""" | r""" | ||||
针对一个 batch 的数据的训练过程; | |||||
针对一个 ``batch`` 的数据的训练过程; | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param batch: 一个 batch 的数据; | |||||
:param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例; | |||||
:param batch: 一个 ``batch`` 的数据; | |||||
""" | """ | ||||
outputs = trainer.train_step(batch) | outputs = trainer.train_step(batch) | ||||
trainer.backward(outputs) | trainer.backward(outputs) | ||||
@@ -1,7 +1,7 @@ | |||||
""" | """ | ||||
``Trainer`` 是 fastNLP 用于训练模型的专门的训练器,其支持多种不同的驱动模式 ``Driver``,不仅包括最为经常使用的 DDP,而且还支持 jittor 等国产 | |||||
的训练框架;新版的 fastNLP 新加入了方便的 callback 函数修饰器,并且支持定制用户自己特定的训练循环过程;通过使用该训练器,用户只需要自己实现 | |||||
模型部分,而将训练层面的逻辑完全地交给 fastNLP; | |||||
``Trainer`` 是 **fastNLP** 用于训练模型的专门的训练器,其支持多种不同的驱动模式 ``Driver``,不仅包括最为经常使用的 DDP,而且还支持 jittor 等国产 | |||||
的训练框架;新版的 **fastNLP** 新加入了方便的 callback 函数修饰器,并且支持定制用户自己特定的训练循环过程;通过使用该训练器,用户只需要自己实现 | |||||
模型部分,而将训练层面的逻辑完全地交给 **fastNLP**; | |||||
""" | """ | ||||
from typing import Union, Optional, List, Callable, Dict, BinaryIO | from typing import Union, Optional, List, Callable, Dict, BinaryIO | ||||
@@ -35,13 +35,14 @@ from fastNLP.envs import rank_zero_call | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | ||||
from fastNLP.core.utils.exceptions import EarlyStopException | from fastNLP.core.utils.exceptions import EarlyStopException | ||||
from fastNLP.core.dataloaders import OverfitDataLoader | |||||
class Trainer(TrainerEventTrigger): | class Trainer(TrainerEventTrigger): | ||||
r""" | r""" | ||||
用于支持快速训练的训练器。 | 用于支持快速训练的训练器。 | ||||
:param model: 训练所需要的模型,例如 ``torch.nn.Module``; | |||||
:param model: 训练所需要的模型,例如 :class:`torch.nn.Module`; | |||||
.. note:: | .. note:: | ||||
@@ -54,9 +55,17 @@ class Trainer(TrainerEventTrigger): | |||||
您应当使用 ``TorchDDPDriver``,意味着您需要通过 ``python -m torch.distributed.launch`` 的方式来启动训练,此时参数 ``device`` | 您应当使用 ``TorchDDPDriver``,意味着您需要通过 ``python -m torch.distributed.launch`` 的方式来启动训练,此时参数 ``device`` | ||||
应当设置为 None(此时我们会忽略该参数),具体见下面对于参数 ``device`` 的更详细的解释。 | 应当设置为 None(此时我们会忽略该参数),具体见下面对于参数 ``device`` 的更详细的解释。 | ||||
:param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["torch"],之后我们会加入 jittor、paddle 等 | |||||
国产框架的训练模式;其中 "torch" 表示使用 ``TorchSingleDriver`` 或者 ``TorchDDPDriver``,具体使用哪一种取决于参数 ``device`` | |||||
的设置; | |||||
:param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:``["auto", "torch", "paddle", "jittor", "fairscale", "deepspeed", "oneflow"]``: | |||||
1. 值为 ``"auto"`` 时,**FastNLP** 会根据传入模型的类型自行判断使用哪一种模式; | |||||
2. 其值为 ``"torch"`` 时,表示使用 :class:`~fastNLP.core.drivers.TorchSingleDriver` 或者 :class:`~fastNLP.core.drivers.TorchDDPDriver`; | |||||
3. 其值为 ``"paddle"`` 时,表示使用 :class:`~fastNLP.core.drivers.PaddleSingleDriver` 或者 :class:`~fastNLP.core.drivers.PaddleFleetDriver`; | |||||
4. 其值为 ``"jittor"`` 时,表示使用 :class:`~fastNLP.core.drivers.JittorSingleDriver` 或者 :class:`~fastNLP.core.drivers.JittorMPIDriver`; | |||||
5. 其值为 ``"fairscale"`` 时,表示使用 :class:`~fastNLP.core.drivers.FairScaleDriver`; | |||||
6. 其值为 ``"deepspeed"`` 时,表示使用 :class:`~fastNLP.core.drivers.DeepSpeedDriver`; | |||||
7. 其值为 ``"oneflow"`` 时,表示使用 :class:`~fastNLP.core.drivers.OneflowSingleDriver` 或者 :class:`~fastNLP.core.drivers.OneflowDDPDriver`; | |||||
在指定了框架的情况下,具体使用哪一种取决于参数 ``device`` 的设置; | |||||
.. warning:: | .. warning:: | ||||
@@ -64,19 +73,28 @@ class Trainer(TrainerEventTrigger): | |||||
这意味着当您传入一个 ``Driver`` 实例时,您传入给 ``Trainer`` 的 ``model`` 参数将会被忽略;也就是说模型在训练时使用的真正的模型是 | 这意味着当您传入一个 ``Driver`` 实例时,您传入给 ``Trainer`` 的 ``model`` 参数将会被忽略;也就是说模型在训练时使用的真正的模型是 | ||||
您传入的 ``Driver`` 实例中的模型; | 您传入的 ``Driver`` 实例中的模型; | ||||
:param train_dataloader: 训练数据集,注意其必须是单独的一个数据集,不能是 List 或者 Dict; | |||||
:param train_dataloader: 训练数据集,注意其必须是单独的一个数据集,不能是 :class:`List` 或者 :class:`Dict`; | |||||
.. warning:: | |||||
当使用分布式训练时, **fastNLP** 会默认将 ``dataloader`` 中的 ``Sampler`` 进行处理,以使得在一个 epoch 中,不同卡 | |||||
用以训练的数据是不重叠的。如果你对 sampler 有特殊处理,那么请将 ``use_dist_sampler`` 参数设置为 ``False`` ,此刻需要由 | |||||
你自身保证每张卡上所使用的数据是不同的。 | |||||
:param optimizers: 训练所需要的优化器;可以是单独的一个优化器实例,也可以是多个优化器组成的 List; | :param optimizers: 训练所需要的优化器;可以是单独的一个优化器实例,也可以是多个优化器组成的 List; | ||||
:param device: 该参数用来指定具体训练时使用的机器;注意当该参数仅当您通过 `torch.distributed.launch/run` 启动时可以为 None, | |||||
此时 fastNLP 不会对模型和数据进行设备之间的移动处理,但是你可以通过参数 `input_mapping` 和 `output_mapping` 来实现设备之间 | |||||
数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也可以通过在 kwargs 添加参数 "data_device" 来让我们帮助您将数据 | |||||
:param device: 该参数用来指定具体训练时使用的机器;注意当该参数仅当您通过 ``torch.distributed.launch/run`` 启动时可以为 ``None``, | |||||
此时 fastNLP 不会对模型和数据进行设备之间的移动处理,但是你可以通过参数 ``input_mapping`` 和 ``output_mapping`` 来实现设备之间 | |||||
数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也可以通过在 kwargs 添加参数 ``data_device`` 来让我们帮助您将数据 | |||||
迁移到指定的机器上(注意这种情况理应只出现在用户在 Trainer 实例化前自己构造 DDP 的场景); | 迁移到指定的机器上(注意这种情况理应只出现在用户在 Trainer 实例化前自己构造 DDP 的场景); | ||||
device 的可选输入如下所示: | device 的可选输入如下所示: | ||||
* *str*: 例如 'cpu', 'cuda', 'cuda:0', 'cuda:1' 等; | |||||
* *torch.device*: 例如 'torch.device("cuda:0")'; | |||||
* *int*: 将使用 ``device_id`` 为该值的 ``gpu`` 进行训练;如果值为 -1,那么默认使用全部的显卡,此时使用的 driver 实例是 `TorchDDPDriver`; | |||||
* *list(int)*: 如果多于 1 个device,应当通过该种方式进行设定;注意此时我们一定会使用 ``TorchDDPDriver``,不管您传入的列表的长度是 1 还是其它值; | |||||
* *str*: 例如 ``'cpu'``, ``'cuda'``, ``'cuda:0'``, ``'cuda:1'``, ``'gpu:0'`` 等; | |||||
* *torch.device*: 例如 ``torch.device("cuda:0")``; | |||||
* *oneflow.device*:例如 ``oneflow.device("cuda", 0)``; | |||||
* *int*: 将使用 ``device_id`` 为该值的 ``gpu`` 进行训练;如果值为 -1,那么默认使用全部的显卡,此时使用的 driver 实例是 `TorchDDPDriver` 这类 | |||||
执行分布式训练的 Driver | |||||
* *list(int)*: 如果多于 1 个device,应当通过该种方式进行设定;注意此时我们一定会使用分布式训练的 Driver ,不管您传入的列表的长度是 1 还是其它值; | |||||
* *None*: 仅当用户自己通过训练框架提供的并行训练启动脚本开启 ddp 进程时为 None; | * *None*: 仅当用户自己通过训练框架提供的并行训练启动脚本开启 ddp 进程时为 None; | ||||
.. note:: | .. note:: | ||||
@@ -93,9 +111,9 @@ class Trainer(TrainerEventTrigger): | |||||
.. warning:: | .. warning:: | ||||
注意参数 ``device`` 仅当您通过 pytorch 或者其它训练框架自身的并行训练启动脚本启动 ddp 训练时才允许为 ``None``! | |||||
注意参数 ``device`` 仅当您通过训练框架自身的并行训练启动脚本启动 ddp 训练时才允许为 ``None``! | |||||
例如,当您使用:: | |||||
例如,在 pytorch 中,当您使用:: | |||||
python -m torch.distributed.launch --nproc_per_node 2 train.py | python -m torch.distributed.launch --nproc_per_node 2 train.py | ||||
@@ -112,7 +130,7 @@ class Trainer(TrainerEventTrigger): | |||||
:param n_epochs: 训练总共的 epoch 的数量,默认为 20;也可以通过 ``n_batches`` 参数设置总共迭代多少个 ``batch`` 。 | :param n_epochs: 训练总共的 epoch 的数量,默认为 20;也可以通过 ``n_batches`` 参数设置总共迭代多少个 ``batch`` 。 | ||||
:param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 | :param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 | ||||
为 None; | |||||
为 ``None``; | |||||
:param batch_step_fn: 定制每次训练时前向运行一个 batch 的数据所执行的函数。该函数应接受两个参数为 ``trainer`` 和 ``batch``, | :param batch_step_fn: 定制每次训练时前向运行一个 batch 的数据所执行的函数。该函数应接受两个参数为 ``trainer`` 和 ``batch``, | ||||
不需要要返回值;更详细的使用位置和说明请见 :meth:`~fastNLP.core.controllers.TrainBatchLoop.batch_step_fn`; | 不需要要返回值;更详细的使用位置和说明请见 :meth:`~fastNLP.core.controllers.TrainBatchLoop.batch_step_fn`; | ||||
:param evaluate_batch_step_fn: 定制每次验证时前向运行一个 batch 的数据所执行的函数。该函数应接受的两个参数为 ``evaluator`` 和 ``batch``, | :param evaluate_batch_step_fn: 定制每次验证时前向运行一个 batch 的数据所执行的函数。该函数应接受的两个参数为 ``evaluator`` 和 ``batch``, | ||||
@@ -124,8 +142,8 @@ class Trainer(TrainerEventTrigger): | |||||
.. note:: | .. note:: | ||||
在 fastNLP 中,对于训练时使用的前向传播函数的查找逻辑如下所示: | 在 fastNLP 中,对于训练时使用的前向传播函数的查找逻辑如下所示: | ||||
1. 如果 ``train_fn`` 为 None,那么在 model 的类 Model 中寻找方法 ``Model.train_step``;如果没有找到,那么默认使用 ``Model.forward``; | |||||
2. 如果 ``train_fn`` 为一个字符串,例如 'my_step_fn',那么我们首先会在 model 的类 Model 中寻找方法 ``Model.my_step_fn``, | |||||
1. 如果 ``train_fn`` 为 None,那么在 model 的类 Model 中寻找方法 :meth:`Model.train_step` ;如果没有找到,那么默认使用 :meth:`Model.forward`; | |||||
2. 如果 ``train_fn`` 为一个字符串,例如 ``'my_step_fn'``,那么我们首先会在 model 的类 Model 中寻找方法 :meth:`Model.my_step_fn`, | |||||
如果没有找到,那么会直接报错; | 如果没有找到,那么会直接报错; | ||||
:param evaluate_fn: 用来控制 ``Trainer`` 中内置的 ``Evaluator`` 在验证的前向传播过程中是调用模型的哪一个函数,应当为 ``None`` | :param evaluate_fn: 用来控制 ``Trainer`` 中内置的 ``Evaluator`` 在验证的前向传播过程中是调用模型的哪一个函数,应当为 ``None`` | ||||
@@ -133,7 +151,7 @@ class Trainer(TrainerEventTrigger): | |||||
:param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 ``Callback`` 类;具体可见 | :param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 ``Callback`` 类;具体可见 | ||||
:class:`~fastNLP.core.callbacks.Callback`; | :class:`~fastNLP.core.callbacks.Callback`; | ||||
:param metrics: 用于传给 ``Trainer`` 内部的 ``Evaluator`` 实例来进行训练过程中的验证。其应当为一个字典,其中 key 表示 monitor, | :param metrics: 用于传给 ``Trainer`` 内部的 ``Evaluator`` 实例来进行训练过程中的验证。其应当为一个字典,其中 key 表示 monitor, | ||||
例如 {"acc1": AccMetric(), "acc2": AccMetric()}; | |||||
例如 ``{"acc1": AccMetric(), "acc2": AccMetric()}``; | |||||
目前我们支持的 ``metric`` 的种类有以下几种: | 目前我们支持的 ``metric`` 的种类有以下几种: | ||||
@@ -147,7 +165,7 @@ class Trainer(TrainerEventTrigger): | |||||
1. 为负数时表示每隔几个 ``epoch`` evaluate 一次; | 1. 为负数时表示每隔几个 ``epoch`` evaluate 一次; | ||||
2. 为正数则表示每隔几个 ``batch`` evaluate 一次; | 2. 为正数则表示每隔几个 ``batch`` evaluate 一次; | ||||
3. 为函数时表示用户自己传入的用于控制 evaluate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 | 3. 为函数时表示用户自己传入的用于控制 evaluate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 | ||||
返回一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 ``batch`` 结束后调用该函数判断是否需要 evaluate; | |||||
返回一个 bool 值,返回为 ``True`` 说明需要进行 evaluate ;将在每个 ``batch`` 结束后调用该函数判断是否需要 evaluate; | |||||
.. note:: | .. note:: | ||||
@@ -199,7 +217,7 @@ class Trainer(TrainerEventTrigger): | |||||
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为; | :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为; | ||||
1. 如果该值为 ``False``,并且当 batch 为字典时,我们会根据**前向函数**所需要的参数从 batch 中提取对应的对象,然后传入到**前向函数**中; | |||||
1. 如果该值为 ``False``,并且当 batch 为字典时,我们会根据 **前向函数** 所需要的参数从 batch 中提取对应的对象,然后传入到 **前向函数** 中; | |||||
2. 如果该值为 ``True``,那么我们会将 batch 直接透传给模型; | 2. 如果该值为 ``True``,那么我们会将 batch 直接透传给模型; | ||||
.. todo:: | .. todo:: | ||||
@@ -210,8 +228,8 @@ class Trainer(TrainerEventTrigger): | |||||
:param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 才让优化器迭代一次,默认为 1; | :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 才让优化器迭代一次,默认为 1; | ||||
:param fp16: 是否开启混合精度训练,默认为 False; | :param fp16: 是否开启混合精度训练,默认为 False; | ||||
:param monitor: 对于一些特殊的 ``Callback``,例如 :class:`~fastNLP.core.callbacks.CheckpointCallback`,它们需要参数 ``monitor`` | :param monitor: 对于一些特殊的 ``Callback``,例如 :class:`~fastNLP.core.callbacks.CheckpointCallback`,它们需要参数 ``monitor`` | ||||
来从 ``Evaluator`` 的验证结果中获取当前评测的值,从而来判断是否执行一些特殊的操作。例如,对于 ``CheckpointCallback`` 而言,如果我们 | |||||
想要每隔一个 epoch 让 ``Evaluator`` 进行一次验证,然后保存训练以来的最好的结果;那么我们需要这样设置: | |||||
来从 ``Evaluator`` 的验证结果中获取当前评测的值,从而来判断是否执行一些特殊的操作。例如,对于 :class:`~fastNLP.core.callbacks.CheckpointCallback` | |||||
而言,如果我们想要每隔一个 epoch 让 ``Evaluator`` 进行一次验证,然后保存训练以来的最好的结果;那么我们需要这样设置: | |||||
.. code-block:: | .. code-block:: | ||||
@@ -225,7 +243,7 @@ class Trainer(TrainerEventTrigger): | |||||
)] | )] | ||||
) | ) | ||||
这意味着对于 ``CheckpointCallback`` 来说,*'acc'* 就是一个监测的指标,用于在 ``Evaluator`` 验证后取出其需要监测的那个指标的值。 | |||||
这意味着对于 :class:`~fastNLP.core.callbacks.CheckpointCallback` 来说,*'acc'* 就是一个监测的指标,用于在 ``Evaluator`` 验证后取出其需要监测的那个指标的值。 | |||||
``Trainer`` 中的参数 ``monitor`` 的作用在于为没有设置 ``monitor`` 参数但是需要该参数的 *callback* 实例设置该值。关于 ``monitor`` | ``Trainer`` 中的参数 ``monitor`` 的作用在于为没有设置 ``monitor`` 参数但是需要该参数的 *callback* 实例设置该值。关于 ``monitor`` | ||||
参数更详细的说明,请见 :class:`~fastNLP.core.callbacks.CheckpointCallback`; | 参数更详细的说明,请见 :class:`~fastNLP.core.callbacks.CheckpointCallback`; | ||||
@@ -237,9 +255,22 @@ class Trainer(TrainerEventTrigger): | |||||
注意该参数仅当 ``Trainer`` 内置的 ``Evaluator`` 不为 None 时且有需要该参数但是没有设置该参数的 *callback* 实例才有效; | 注意该参数仅当 ``Trainer`` 内置的 ``Evaluator`` 不为 None 时且有需要该参数但是没有设置该参数的 *callback* 实例才有效; | ||||
:param n_batches: 迭代多少个 ``batch`` 的训练结束。当该值不为 -1 时,将直接忽略 ``n_epochs`` 的值。 | |||||
:param n_batches: 总共迭代多少个 ``batch`` 的训练结束。当该值不为 -1 时,将直接忽略 ``n_epochs`` 的值。 | |||||
:param overfit_batches: 使用该参数来支持 **'过拟合'** 的功能;支持的值为 ``-1``、``0`` 或者 大于 0 的整数,表示使用多少个 batch 的数据 | |||||
来进行过拟合训练;其中 0 为表示不进行任何操作;-1 表示使用所有的数据进行训练; | |||||
:param marker: 用于标记一个 ``Trainer`` 实例,从而在用户调用 ``Trainer.on`` 函数时,标记该函数属于哪一个具体的 ``Trainer`` 实例;默认为 None; | |||||
.. note:: | |||||
您可以使用该参数来简单地查看您的模型是否是 '正确的',即您的模型是否能够在少量的数据上快速进行收敛,从而说明损失函数以及优化器等 | |||||
没有问题。当使用该参数时,我们会直接从 ``train_dataloader`` 中提取固定数量的 batch,然后在所有 epoch 中都是用这些数据 | |||||
来进行训练; | |||||
.. warning:: | |||||
在使用该参数时,您同样可以指定 ``metrics`` 参数来进行简单的验证,当该参数和 ``metrics`` 同时出现时,我们会将 ``evaluate_dataloaders`` | |||||
直接替换为在过拟合中所使用的训练数据;因此您需要保证您的 ``metrics`` 是能够在 ``train_dataloader`` 上使用的; | |||||
:param marker: 用于标记一个 ``Trainer`` 实例,从而在用户调用 ``Trainer.on`` 函数时,标记该函数属于哪一个具体的 ``Trainer`` 实例;默认为 ``None``; | |||||
.. note:: | .. note:: | ||||
@@ -261,47 +292,52 @@ class Trainer(TrainerEventTrigger): | |||||
) | ) | ||||
另一点需要说明的是,如果一个被 ``Trainer.on`` 修饰的函数,其修饰时没有指明 ``marker``,那么会将该函数传给代码位于其之后的 | 另一点需要说明的是,如果一个被 ``Trainer.on`` 修饰的函数,其修饰时没有指明 ``marker``,那么会将该函数传给代码位于其之后的 | ||||
第一个 ``Trainer`` 实例,即使该 ``Trainer`` 实例的 marker 不为 None;这一点详见 :meth:`~fastNLP.core.controllers.Trainer.on` | |||||
第一个 ``Trainer`` 实例,即使该 ``Trainer`` 实例的 marker 不为 ``None``;这一点详见 :meth:`~fastNLP.core.controllers.Trainer.on` | |||||
:kwargs: | :kwargs: | ||||
* *torch_kwargs* -- 用于在指定 ``driver`` 为 'torch' 时设定具体 driver 实例的一些参数: | |||||
* ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入 | |||||
{'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等; | |||||
* set_grad_to_none -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | |||||
* non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; | |||||
* gradscaler_kwargs -- 用于 fp16=True 时,提供给 ``torch.amp.cuda.GradScaler`` 的参数。 | |||||
* *paddle_kwargs* -- 用于在指定 ``driver`` 为 'paddle' 时设定具体 driver 实例的一些参数: | |||||
* fleet_kwargs -- 用于在使用 ``PaddleFleetDriver`` 时指定 ``DataParallel`` 和 ``fleet`` 初始化时的参数,包括: | |||||
* is_collective -- 是否使用 paddle 集群式的分布式训练方法,目前仅支持为 ``True`` 的情况; | |||||
* role_maker -- 初始化 ``fleet`` 分布式训练 API 时使用的 ``RoleMaker`` | |||||
* 其它用于初始化 ``DataParallel`` 的参数; | |||||
* *torch_kwargs* -- ``TorchDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.TorchSingleDriver` 和 | |||||
:class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver`; | |||||
* *paddle_kwargs* -- ``PaddleDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.paddle_driver.PaddleSingleDriver` 和 | |||||
:class:`~fastNLP.core.drivers.paddle_driver.PaddleSingleDriver`; | |||||
* *fairscale_kwargs* -- ``FairScaleDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.FairScaleDriver`; | |||||
* *deepspeed_kwargs* -- ``DeepSpeedDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.DeepSpeedDriver`; | |||||
* *torch_kwargs* -- ``OneflowDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.oneflow_driver.OneflowSingleDriver` 和 | |||||
:class:`~fastNLP.core.drivers.oneflow_driver.OneflowDDPDriver`; | |||||
* *data_device* -- 一个具体的 driver 实例中,有 ``model_device`` 和 ``data_device``,前者表示模型所在的设备,后者表示 | * *data_device* -- 一个具体的 driver 实例中,有 ``model_device`` 和 ``data_device``,前者表示模型所在的设备,后者表示 | ||||
当 ``model_device`` 为 None 时应当将数据迁移到哪个设备; | 当 ``model_device`` 为 None 时应当将数据迁移到哪个设备; | ||||
.. note:: | .. note:: | ||||
注意您在绝大部分情况下不会用到该参数! | |||||
**注意您在绝大部分情况下不会用到该参数!** | |||||
1. 当 driver 实例的 ``model_device`` 不为 None 时,该参数无效; | 1. 当 driver 实例的 ``model_device`` 不为 None 时,该参数无效; | ||||
2. 对于 pytorch,仅当用户自己通过 ``python -m torch.distributed.launch`` 并且自己初始化 ``init_process_group`` 时, | |||||
driver 实例的 ``model_device`` 才会为 None; | |||||
3. 对于 paddle,该参数无效; | |||||
2. 对于 **pytorch**,仅当用户自己通过 ``python -m torch.distributed.launch`` 并且自己初始化 ``init_process_group`` 时, | |||||
driver 实例的 ``model_device`` 才会为 None; | |||||
2. 对于 **deepspeed**,仅当用户自己通过 ``deepspeed xxx.py`` 并且自己初始化 ``model.initialize`` 时, | |||||
driver 实例的 ``model_device`` 才会为 None; | |||||
3. 对于 **paddle** 和 **oneflow**,该参数无效; | |||||
* *use_dist_sampler* -- 表示是否使用分布式的 ``sampler``。在多卡时,分布式 ``sampler`` 将自动决定每张卡上读取的 sample ,使得一个 epoch | * *use_dist_sampler* -- 表示是否使用分布式的 ``sampler``。在多卡时,分布式 ``sampler`` 将自动决定每张卡上读取的 sample ,使得一个 epoch | ||||
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 | |||||
* *evaluate_use_dist_sampler* -- 表示在 ``Evaluator`` 中在使用分布式的时候是否将 dataloader 的 ``sampler`` 替换为分布式的 ``sampler``; | |||||
不传入该值时,该值与 ``use_dist_sampler`` 参数保持一致; | |||||
内所有卡的 sample 加起来为一整个数据集的 sample,同时为了保证所有卡上拥有相同数量的 sample ,有的卡上可能会有重复的 sample ,例如 | |||||
8卡训练,只有9个 sample ,如果 batch_size 为 1,那么第二个 batch 时,有7张卡将没有 sample 可用,因此只有 **重复** 使用 sample 来 pad 到第二个 | |||||
batch 中。如果不希望 fastNLP 对 dataloader 的 sampler 做特殊设置,请将该值设置为 False ,若确实需要分布式的训练,请在 Trainer 外 | |||||
对 ``train_dataloader`` 做的数据做特殊处理使得其在不同的卡之间 sample 是不同的。 | |||||
* *evaluate_use_dist_sampler* -- 表示在 ``Evaluator`` 中在使用分布式的时候是否将保证 dataloader 的 ``sampler`` 替换为 | |||||
evaluate 时使用的分布式的 ``sampler``,其特点是每个卡上的数据之间不重叠,所有卡上数据的加起来是整个数据集。若传入的 dataloader | |||||
的 sampler 为: | |||||
- 深度学习框架自带的默认 sampler ; | |||||
- fastNLP 的 Sampler ; | |||||
则将替换为 :class:`~fastNLP.UnrepeatedSequentialSampler`,如果这个行为不是期待的,请本参数设置为 ``False``,并针对每个卡控制其可以 | |||||
用到的数据。 | |||||
* *output_from_new_proc* -- 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | * *output_from_new_proc* -- 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | ||||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | |||||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | |||||
``["all", "ignore", "only_error"]``;当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | |||||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 ``"only_error"``; | |||||
注意该参数仅当使用分布式的 ``driver`` 时才有效,例如 ``TorchDDPDriver``; | 注意该参数仅当使用分布式的 ``driver`` 时才有效,例如 ``TorchDDPDriver``; | ||||
* *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto', 'tqdm'] 或者 :class:`~.fastNLP.RichCallback`, :class:`~fastNLP.RawTextCallback`等对象, | |||||
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 :class:`~fastNLP.RichCallback`,否则使用 :class:`~fastNLP.RawTextCallback` 对象。如果 | |||||
需要定制 progress bar 的参数,例如打印频率等,可以传入 :class:`~fastNLP.RichCallback`, :class:`~fastNLP.RawTextCallback` 等对象。 | |||||
* *progress_bar* -- 显示进度条的方式,目前支持 ``[None, 'raw', 'rich', 'auto', 'tqdm']`` 或者 :class:`~fastNLP.RichCallback` 、 :class:`~fastNLP.RawTextCallback` 等对象, | |||||
默认为 ``'auto'`` , ``'auto'`` 表示如果检测到当前 terminal 为交互型则使用 :class:`~fastNLP.RichCallback`,否则使用 :class:`~fastNLP.RawTextCallback` 对象。如果 | |||||
需要定制 progress bar 的参数,例如打印频率等,可以传入 :class:`~fastNLP.RichCallback`, :class:`~fastNLP.RawTextCallback` 等对象。 | |||||
* *train_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Trainer`` 中。与 input_mapping 互斥。 | * *train_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Trainer`` 中。与 input_mapping 互斥。 | ||||
* *train_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Trainer`` 中。与 output_mapping 互斥。 | * *train_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Trainer`` 中。与 output_mapping 互斥。 | ||||
* *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Evaluator`` 中。与 input_mapping 互斥。 | * *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Evaluator`` 中。与 input_mapping 互斥。 | ||||
@@ -312,19 +348,19 @@ class Trainer(TrainerEventTrigger): | |||||
``Trainer`` 是通过在内部直接初始化一个 ``Evaluator`` 来进行验证; | ``Trainer`` 是通过在内部直接初始化一个 ``Evaluator`` 来进行验证; | ||||
``Trainer`` 内部的 ``Evaluator`` 默认是 None,如果您需要在训练过程中进行验证,你需要保证这几个参数得到正确的传入: | ``Trainer`` 内部的 ``Evaluator`` 默认是 None,如果您需要在训练过程中进行验证,你需要保证这几个参数得到正确的传入: | ||||
必须的参数:1. ``metrics``;2. ``evaluate_dataloaders``; | |||||
必须的参数:``metrics`` 与 ``evaluate_dataloaders``; | |||||
可选的其它参数:1. ``evaluate_batch_step_fn;2. ``evaluate_fn``;3. ``evaluate_every``;4. ``input_mapping``; | |||||
5. ``output_mapping``; 6. ``model_wo_auto_param_call``;7. ``fp16``;8. ``monitor``;9. ``larger_better``; | |||||
可选的其它参数:``evaluate_batch_step_fn``、 ``evaluate_fn``、``evaluate_every``、``input_mapping``、 | |||||
``output_mapping``、``model_wo_auto_param_call``、``fp16``、``monitor``、``larger_better``; | |||||
.. warning:: | .. warning:: | ||||
如果 ``Trainer`` 中内置的 ``Evaluator`` 实例不为 ``None``,那么需要注意 ``Trainer`` 中的一些参数是与 ``Evaluator`` 一致的,它们分别为: | 如果 ``Trainer`` 中内置的 ``Evaluator`` 实例不为 ``None``,那么需要注意 ``Trainer`` 中的一些参数是与 ``Evaluator`` 一致的,它们分别为: | ||||
1. ``Evaluator`` 在初始化时的 ``driver`` 参数是 ``Trainer`` 中已经实例化过的 driver;这一点使得一些参数对于 ``Trainer`` 内部的 | 1. ``Evaluator`` 在初始化时的 ``driver`` 参数是 ``Trainer`` 中已经实例化过的 driver;这一点使得一些参数对于 ``Trainer`` 内部的 | ||||
``Evaluator`` 没有用处,例如 ``device``,``torch_kwargs``,``data_device`` 和 ``output_from_new_proc`` 等; | |||||
``Evaluator`` 没有用处,例如 ``device``,``torch_kwargs``,``data_device`` 和 ``output_from_new_proc`` 等; | |||||
2. ``input_mapping``,``output_mapping``,``model_wo_auto_param_call`` 和 ``fp16`` 是 ``Trainer`` 和其内部默认的 | 2. ``input_mapping``,``output_mapping``,``model_wo_auto_param_call`` 和 ``fp16`` 是 ``Trainer`` 和其内部默认的 | ||||
``Evaluator`` 是一致的; | |||||
``Evaluator`` 是一致的; | |||||
当然,对于 ``input_mapping`` 和 ``output_mapping``,您可以通过添加 ``kwargs`` 中的参数 ``evaluate_input_mapping`` 和 | 当然,对于 ``input_mapping`` 和 ``output_mapping``,您可以通过添加 ``kwargs`` 中的参数 ``evaluate_input_mapping`` 和 | ||||
``evaluate_output_mapping`` 来单独为 ``Evaluator`` 进行更细致的订制。 | ``evaluate_output_mapping`` 来单独为 ``Evaluator`` 进行更细致的订制。 | ||||
@@ -338,9 +374,9 @@ class Trainer(TrainerEventTrigger): | |||||
def __init__( | def __init__( | ||||
self, | self, | ||||
model, | model, | ||||
driver, | |||||
train_dataloader, | train_dataloader, | ||||
optimizers, | optimizers, | ||||
driver: str = "auto", | |||||
device: Optional[Union[int, List[int], str]] = "cpu", | device: Optional[Union[int, List[int], str]] = "cpu", | ||||
n_epochs: int = 20, | n_epochs: int = 20, | ||||
evaluate_dataloaders=None, | evaluate_dataloaders=None, | ||||
@@ -359,6 +395,7 @@ class Trainer(TrainerEventTrigger): | |||||
monitor: Union[str, Callable] = None, | monitor: Union[str, Callable] = None, | ||||
larger_better: bool = True, | larger_better: bool = True, | ||||
n_batches: int = -1, | n_batches: int = -1, | ||||
overfit_batches: int = 0, | |||||
marker: Optional[str] = None, | marker: Optional[str] = None, | ||||
**kwargs | **kwargs | ||||
): | ): | ||||
@@ -456,9 +493,6 @@ class Trainer(TrainerEventTrigger): | |||||
n_batches=n_batches | n_batches=n_batches | ||||
) | ) | ||||
if metrics is None and evaluate_dataloaders is not None: | |||||
raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.") | |||||
if metrics is not None and evaluate_dataloaders is None: | if metrics is not None and evaluate_dataloaders is None: | ||||
raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloaders'.") | raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloaders'.") | ||||
@@ -482,33 +516,42 @@ class Trainer(TrainerEventTrigger): | |||||
else: | else: | ||||
_dist_sampler = None | _dist_sampler = None | ||||
self.dataloader = self.train_dataloader | |||||
self.driver.set_deterministic_dataloader(self.dataloader) | |||||
self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler, | |||||
reproducible=self.callback_manager._need_reproducible_sampler) | |||||
# 进行 overfit 相关的设置; | |||||
if overfit_batches != 0: | |||||
self.dataloader = OverfitDataLoader(self.dataloader, overfit_batches) | |||||
self.overfit_batches = overfit_batches | |||||
self.evaluator = None | self.evaluator = None | ||||
self.monitor = monitor | self.monitor = monitor | ||||
self.larger_better = larger_better | self.larger_better = larger_better | ||||
if metrics is not None and evaluate_dataloaders is not None: | |||||
check_evaluate_every(evaluate_every) | |||||
progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为 | |||||
if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。 | |||||
progress_bar = progress_bar.name | |||||
self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics, | |||||
driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn, | |||||
evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping, | |||||
output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0, | |||||
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler), | |||||
progress_bar=progress_bar, | |||||
check_dataloader_legality=kwargs.get('check_dataloader_legality', True)) | |||||
if metrics is not None: | |||||
if overfit_batches != 0: | |||||
evaluate_dataloaders = self.dataloader | |||||
if evaluate_dataloaders is not None: | |||||
check_evaluate_every(evaluate_every) | |||||
progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为 | |||||
if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。 | |||||
progress_bar = progress_bar.name | |||||
self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics, | |||||
driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn, | |||||
evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping, | |||||
output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0, | |||||
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler), | |||||
progress_bar=progress_bar, | |||||
check_dataloader_legality=kwargs.get('check_dataloader_legality', True)) | |||||
else: | |||||
raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.") | |||||
if train_fn is not None and not isinstance(train_fn, str): | if train_fn is not None and not isinstance(train_fn, str): | ||||
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") | raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") | ||||
self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn) | self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn) | ||||
self.train_fn = train_fn | self.train_fn = train_fn | ||||
self.dataloader = self.train_dataloader | |||||
self.driver.set_deterministic_dataloader(self.dataloader) | |||||
self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler, | |||||
reproducible=self.callback_manager._need_reproducible_sampler) | |||||
self.evaluate_batch_step_fn = evaluate_batch_step_fn | self.evaluate_batch_step_fn = evaluate_batch_step_fn | ||||
self.kwargs = kwargs | self.kwargs = kwargs | ||||
@@ -521,17 +564,17 @@ class Trainer(TrainerEventTrigger): | |||||
r""" | r""" | ||||
该函数是在 ``Trainer`` 初始化后用于真正开始训练的函数; | 该函数是在 ``Trainer`` 初始化后用于真正开始训练的函数; | ||||
注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ``CheckpointCallback`` | |||||
注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 ``resume_from`` 为 ``None``,并且使用 ``CheckpointCallback`` | |||||
去保存断点重训的文件; | 去保存断点重训的文件; | ||||
:param num_train_batch_per_epoch: 每个 epoch 训练多少个 batch 后停止,*-1* 表示使用 train_dataloader 本身的长度; | |||||
:param num_eval_batch_per_dl: 每个 evaluate_dataloader 验证多少个 batch 停止,*-1* 表示使用 evaluate_dataloader 本身的长度; | |||||
:param num_train_batch_per_epoch: 每个 epoch 训练多少个 batch 后停止,*-1* 表示使用 ``train_dataloader`` 本身的长度; | |||||
:param num_eval_batch_per_dl: 每个 ``evaluate_dataloader`` 验证多少个 batch 停止,*-1* 表示使用 ``evaluate_dataloader`` 本身的长度; | |||||
:param num_eval_sanity_batch: 在训练之前运行多少个 evaluation batch 来检测一下 evaluation 的过程是否有错误。为 0 表示不检测; | :param num_eval_sanity_batch: 在训练之前运行多少个 evaluation batch 来检测一下 evaluation 的过程是否有错误。为 0 表示不检测; | ||||
:param resume_from: 从哪个路径下恢复 trainer 的状态,注意该值需要为一个文件夹,例如使用 ``CheckpointCallback`` 时帮助您创建的保存的子文件夹; | :param resume_from: 从哪个路径下恢复 trainer 的状态,注意该值需要为一个文件夹,例如使用 ``CheckpointCallback`` 时帮助您创建的保存的子文件夹; | ||||
:param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态;该参数如果为 ``True``, | :param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态;该参数如果为 ``True``, | ||||
在下一次断点重训的时候我们会精确到上次训练截止的具体的 sample 进行训练;否则我们只会恢复 model 和 optimizers 的状态,而 ``Trainer`` 中的 | 在下一次断点重训的时候我们会精确到上次训练截止的具体的 sample 进行训练;否则我们只会恢复 model 和 optimizers 的状态,而 ``Trainer`` 中的 | ||||
其余状态都是保持初始化时的状态不会改变; | 其余状态都是保持初始化时的状态不会改变; | ||||
:param catch_KeyboardInterrupt: 是否捕获 KeyboardInterrupt;如果该参数为 ``True``,在训练时如果您使用 ``ctrl+c`` 来终止程序, | |||||
:param catch_KeyboardInterrupt: 是否捕获 :class:`KeyboardInterrupt`;如果该参数为 ``True``,在训练时如果您使用 ``ctrl+c`` 来终止程序, | |||||
``Trainer`` 不会抛出异常,但是会提前退出,然后 ``trainer.run()`` 之后的代码会继续运行。注意该参数在您使用分布式训练的 ``Driver`` | ``Trainer`` 不会抛出异常,但是会提前退出,然后 ``trainer.run()`` 之后的代码会继续运行。注意该参数在您使用分布式训练的 ``Driver`` | ||||
时无效,例如 ``TorchDDPDriver``;非分布式训练的 ``Driver`` 下该参数默认为 True; | 时无效,例如 ``TorchDDPDriver``;非分布式训练的 ``Driver`` 下该参数默认为 True; | ||||
@@ -552,7 +595,7 @@ class Trainer(TrainerEventTrigger): | |||||
整体的验证流程是否正确; | 整体的验证流程是否正确; | ||||
``num_eval_sanity_batch`` 的作用可能会让人产生迷惑,其本质和 ``num_eval_batch_per_dl`` 作用一致,但是其只被 ``Trainer`` 使用; | ``num_eval_sanity_batch`` 的作用可能会让人产生迷惑,其本质和 ``num_eval_batch_per_dl`` 作用一致,但是其只被 ``Trainer`` 使用; | ||||
并且其只会在训练的一开始使用,意思为:我们在训练的开始时会先使用 ``Evaluator``(如果其不为 ``None``) 进行验证,此时验证的 batch 的 | |||||
并且其只会在训练的一开始使用,意思为:我们在训练的开始时会先使用 ``Evaluator`` (如果其不为 ``None``) 进行验证,此时验证的 batch 的 | |||||
数量只有 ``num_eval_sanity_batch`` 个;但是对于 ``num_eval_batch_per_dl`` 而言,其表示在实际的整体的训练过程中,每次 ``Evaluator`` | 数量只有 ``num_eval_sanity_batch`` 个;但是对于 ``num_eval_batch_per_dl`` 而言,其表示在实际的整体的训练过程中,每次 ``Evaluator`` | ||||
进行验证时会验证的 batch 的数量。 | 进行验证时会验证的 batch 的数量。 | ||||
@@ -698,7 +741,7 @@ class Trainer(TrainerEventTrigger): | |||||
.. note:: | .. note:: | ||||
对于训练一个神经网络的整体的流程来说,其可以分为很多个时间点,例如 **"整体的训练前"**,**"训练具体的一个 epoch 前"**, | 对于训练一个神经网络的整体的流程来说,其可以分为很多个时间点,例如 **"整体的训练前"**,**"训练具体的一个 epoch 前"**, | ||||
**"反向传播前"**,**"整体的训练结束后"**等;一个 ``callback`` 时机指的就是这些一个个具体的时间点; | |||||
**"反向传播前"**,**"整体的训练结束后"** 等;一个 ``callback`` 时机指的就是这些一个个具体的时间点; | |||||
该函数的参数 ``event`` 需要是一个 ``Event`` 实例,其使用方式见下方的例子; | 该函数的参数 ``event`` 需要是一个 ``Event`` 实例,其使用方式见下方的例子; | ||||
@@ -988,10 +1031,11 @@ class Trainer(TrainerEventTrigger): | |||||
r""" | r""" | ||||
用于帮助您保存模型的辅助函数; | 用于帮助您保存模型的辅助函数; | ||||
:param folder: 保存模型的文件夹。如果没有传入 model_save_fn 参数,则我们会在这个文件夹下保存 fastnlp_model.pkl.tar 文件; | |||||
:param only_state_dict: 仅在 model_save_fn 为空时,有效。是否只保存模型的 ``state_dict``; | |||||
:param folder: 保存模型的文件夹。如果没有传入 ``model_save_fn`` 参数,则我们会在这个文件夹下保存 ``fastnlp_model.pkl.tar`` 文件; | |||||
:param only_state_dict: 仅在 ``model_save_fn`` 为空时,有效。是否只保存模型的 ``state_dict``; | |||||
:param model_save_fn: 您自己定制的用来替换该保存函数本身保存逻辑的函数,当您传入了该参数后,我们会实际调用该函数,而不会去调用 ``driver`` 的 ``save_model`` 函数; | :param model_save_fn: 您自己定制的用来替换该保存函数本身保存逻辑的函数,当您传入了该参数后,我们会实际调用该函数,而不会去调用 ``driver`` 的 ``save_model`` 函数; | ||||
:param kwargs: 理论上您不需要使用到该参数; | |||||
:kwargs: | |||||
* *input_spec* -- 该参数详见 **PaddlePaddle** 框架的保存函数 :meth:`~fastNLP.core.drivers.PaddleDriver.save_model` 中的说明; | |||||
.. note:: | .. note:: | ||||
@@ -1030,10 +1074,10 @@ class Trainer(TrainerEventTrigger): | |||||
""" | """ | ||||
用于帮助您加载模型的辅助函数; | 用于帮助您加载模型的辅助函数; | ||||
:param folder: 存放着您需要加载的 model 的文件夹,默认会尝试读取该文件夹下的 fastnlp_model.pkl.tar 文件。在 model_load_fn 不为空时, | |||||
直接将该 folder 传递到 model_load_fn 中; | |||||
:param only_state_dict: 要读取的文件中是否仅包含模型权重。在 ``model_load_fn 不为 None`` 时,该参数无意义; | |||||
:param model_load_fn: ``callable`` 的函数,接受一个 folder 作为参数,需要注意该函数不需要返回任何内容; | |||||
:param folder: 存放着您需要加载的 model 的文件夹,默认会尝试读取该文件夹下的 ``fastnlp_model.pkl.tar`` 文件。在 ``model_load_fn`` | |||||
不为空时,直接将该 folder 传递到 ``model_load_fn`` 中; | |||||
:param only_state_dict: 要读取的文件中是否仅包含模型权重。在 ``model_load_fn`` 不为 ``None`` 时,该参数无意义; | |||||
:param model_load_fn: :class:`Callable` 的函数,接受一个 folder 作为参数,需要注意该函数不需要返回任何内容; | |||||
:param kwargs: 理论上您不需要使用到该参数; | :param kwargs: 理论上您不需要使用到该参数; | ||||
.. note:: | .. note:: | ||||
@@ -1073,12 +1117,13 @@ class Trainer(TrainerEventTrigger): | |||||
用于帮助您实现断点重训功能的保存函数;保存内容包括:callback 状态、Trainer 的状态、Sampler 的状态【在恢复的时候才能恢复到特定 batch 】、 | 用于帮助您实现断点重训功能的保存函数;保存内容包括:callback 状态、Trainer 的状态、Sampler 的状态【在恢复的时候才能恢复到特定 batch 】、 | ||||
模型参数、optimizer的状态、fp16 Scaler的状态【如果有】。 | 模型参数、optimizer的状态、fp16 Scaler的状态【如果有】。 | ||||
:param folder: 保存在哪个文件夹下,会在该文件下声称两个文件:fastnlp_checkpoint.pkl.tar 与 fastnlp_model.pkl.tar 。 | |||||
如果 model_save_fn 不为空,则没有 fastnlp_model.pkl.tar 文件; | |||||
:param only_state_dict: 当 model_save_fn 为空时有效,表明是否仅保存模型的权重; | |||||
:param folder: 保存在哪个文件夹下,会在该文件下生成两个文件:``fastnlp_checkpoint.pkl.tar`` 与 ``fastnlp_model.pkl.tar`` 。 | |||||
如果 ``model_save_fn`` 不为空,则没有 ``fastnlp_model.pkl.tar`` 文件; | |||||
:param only_state_dict: 当 ``model_save_fn`` 为空时有效,表明是否仅保存模型的权重; | |||||
:param model_save_fn: 如果模型保存比较特殊,可以传入该函数自定义模型的保存过程,输入应该接受一个文件夹(实际上就是接受上面的 folder | :param model_save_fn: 如果模型保存比较特殊,可以传入该函数自定义模型的保存过程,输入应该接受一个文件夹(实际上就是接受上面的 folder | ||||
参数),不需要返回值;这意味着您可以通过该函数来自己负责模型的保存过程,而我们则会将 ``trainer`` 的状态保存好; | 参数),不需要返回值;这意味着您可以通过该函数来自己负责模型的保存过程,而我们则会将 ``trainer`` 的状态保存好; | ||||
:param kwargs: 理论上您不需要使用到该参数; | |||||
:kwargs: | |||||
* *input_spec* -- 该参数详见 **PaddlePaddle** 框架的保存函数 :meth:`~fastNLP.core.drivers.PaddleDriver.save_model` 中的说明; | |||||
.. note:: | .. note:: | ||||
@@ -1097,7 +1142,7 @@ class Trainer(TrainerEventTrigger): | |||||
为了支持断点重训功能,我们会在调用该函数时保存以下内容: | 为了支持断点重训功能,我们会在调用该函数时保存以下内容: | ||||
1. 各个 ``callback`` 的状态,这主要涉及到一些带有运行状态的 ``callback``; | 1. 各个 ``callback`` 的状态,这主要涉及到一些带有运行状态的 ``callback``; | ||||
2. 控制训练流程的变量 ``trainer_state``,具体详见 :class:`~fastNLP.core.controllers.utils.states.TrainerState`; | |||||
2. 控制训练流程的变量 ``trainer_state``,具体详见 :class:`~fastNLP.core.controllers.utils.state.TrainerState`; | |||||
3. 一个特殊的变量 ``num_consumed_batches``,表示在这次训练过程中总共训练了多少个 batch 的数据;您不需要关心这个变量; | 3. 一个特殊的变量 ``num_consumed_batches``,表示在这次训练过程中总共训练了多少个 batch 的数据;您不需要关心这个变量; | ||||
4. sampler 的状态,为了支持断点重训功能,我们会在 trainer 初始化的时候,将您的 ``trainer_dataloader`` 的 ``sampler`` 替换为 | 4. sampler 的状态,为了支持断点重训功能,我们会在 trainer 初始化的时候,将您的 ``trainer_dataloader`` 的 ``sampler`` 替换为 | ||||
我们专门用于断点重训功能的 ``ReproducibleSampler``,详见 :class:`~fastNLP.core.samplers.reproducible_sampler.ReproducibleSampler`; | 我们专门用于断点重训功能的 ``ReproducibleSampler``,详见 :class:`~fastNLP.core.samplers.reproducible_sampler.ReproducibleSampler`; | ||||
@@ -1309,6 +1354,11 @@ class Trainer(TrainerEventTrigger): | |||||
用于在使用梯度累积并且进行分布式训练时,由于在前 ``accumulation_steps - 1`` 的时间内不需要进行梯度的同步,因此通过使用该 context 上下文 | 用于在使用梯度累积并且进行分布式训练时,由于在前 ``accumulation_steps - 1`` 的时间内不需要进行梯度的同步,因此通过使用该 context 上下文 | ||||
环境来避免梯度的同步; | 环境来避免梯度的同步; | ||||
.. note:: | |||||
部分深度学习框架的梯度累积并不需要通过提供上下文环境实现,关于这点需要您深入了解您正在使用的框架的机制;而对于这些框架,fastNLP 会返回一个 | |||||
空的上下文环境。 | |||||
:return: 一个支持 ``no_sync`` 的 ``context``; | :return: 一个支持 ``no_sync`` 的 ``context``; | ||||
""" | """ | ||||
@@ -1394,7 +1444,7 @@ class Trainer(TrainerEventTrigger): | |||||
def model_device(self): | def model_device(self): | ||||
r""" | r""" | ||||
:return: 返回当前模型所在的设备;注意该值在当且仅当在少数情况下为 ``None``,例如当使用 ``pytorch`` 时,仅当用户自己初始化 ``init_progress_group`` 时 | :return: 返回当前模型所在的设备;注意该值在当且仅当在少数情况下为 ``None``,例如当使用 ``pytorch`` 时,仅当用户自己初始化 ``init_progress_group`` 时 | ||||
``model_device`` 才为 None; | |||||
``model_device`` 才为 None; | |||||
""" | """ | ||||
return self.driver.model_device | return self.driver.model_device | ||||
@@ -42,7 +42,7 @@ class State(dict): | |||||
class TrainerState: | class TrainerState: | ||||
r""" | r""" | ||||
该类用于我们 fastNLP 自己内部为了训练流程所记录的一些状态,当然是要暴露给用户给用户使用的; | 该类用于我们 fastNLP 自己内部为了训练流程所记录的一些状态,当然是要暴露给用户给用户使用的; | ||||
我们保存的state大部分上是 trainer 断点重训 需要重新加载的; | |||||
我们保存的 state 大部分上是 trainer 断点重训 需要重新加载的; | |||||
专属于 `Trainer` 的状态记载的类; | 专属于 `Trainer` 的状态记载的类; | ||||
:param n_epochs: 训练过程中总共的 epoch 的数量; | :param n_epochs: 训练过程中总共的 epoch 的数量; | ||||
@@ -50,7 +50,7 @@ class TrainerState: | |||||
:param global_forward_batches: 当前模型总共 forward 了多少个 step; | :param global_forward_batches: 当前模型总共 forward 了多少个 step; | ||||
:param batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step; | :param batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step; | ||||
:param num_batches_per_epoch: 每一个 epoch 会 forward 多少个 step; | :param num_batches_per_epoch: 每一个 epoch 会 forward 多少个 step; | ||||
:param n_batches: 完整训练过程会 forward 的 step 数量,注意 n_batches = n_batches * n_epochs; | |||||
:param n_batches: 完整训练过程会 forward 的 step 数量,注意 ``n_batches = num_batches_per_epoch * n_epochs`` ; | |||||
""" | """ | ||||
n_epochs: Optional[int] = None # 无论如何重新算 | n_epochs: Optional[int] = None # 无论如何重新算 | ||||
@@ -73,6 +73,7 @@ class TrainerState: | |||||
def load_state_dict(self, state_dict: Dict): | def load_state_dict(self, state_dict: Dict): | ||||
r""" | r""" | ||||
用于断点重训来重新加载保存的状态字典; | 用于断点重训来重新加载保存的状态字典; | ||||
:param state_dict: 用于加载的状态字典; | :param state_dict: 用于加载的状态字典; | ||||
""" | """ | ||||
for key in state_dict: | for key in state_dict: | ||||
@@ -4,11 +4,12 @@ from fastNLP.core.callbacks import CallbackManager | |||||
from .state import TrainerState | from .state import TrainerState | ||||
from fastNLP.core.utils.utils import _check_valid_parameters_number | from fastNLP.core.utils.utils import _check_valid_parameters_number | ||||
__all__ = [] | |||||
class TrainerEventTrigger: | class TrainerEventTrigger: | ||||
r""" | r""" | ||||
为了避免在训练流程中调用 callback 函数中写成类似 'trainer.callback_manager.on_train_begin' 的形式,我们选择单独抽象为 'Trainer' | |||||
抽象一层,然后一些特殊的操作可以在这里进行,例如我们通过 `on_validate_end` 来通知所有的 'CheckpointCallback' 实例在当前的 step 后保存 | |||||
为了避免在训练流程中调用 callback 函数中写成类似 `'trainer.callback_manager.on_train_begin'` 的形式,我们选择单独为 ``Trainer`` | |||||
抽象一层,然后一些特殊的操作可以在这里进行,例如我们通过 :meth:`on_validate_end` 来通知所有的 ``CheckpointCallback`` 实例在当前的 step 后保存 | |||||
模型。 | 模型。 | ||||
""" | """ | ||||
callback_manager: CallbackManager | callback_manager: CallbackManager | ||||
@@ -138,7 +139,7 @@ def check_evaluate_every(evaluate_every): | |||||
``evaluate_every`` 的使用详见 ``Trainer`` 的 ``evaluate_every`` 参数; | ``evaluate_every`` 的使用详见 ``Trainer`` 的 ``evaluate_every`` 参数; | ||||
主要在于当参数 ``evaluate_every`` 是一个 callable 的函数时,需要保证其参数的正确性; | |||||
主要在于当参数 ``evaluate_every`` 是一个 Callable 的函数时,需要保证其参数的正确性; | |||||
""" | """ | ||||
if not callable(evaluate_every) and (not isinstance(evaluate_every, int) or evaluate_every == 0): | if not callable(evaluate_every) and (not isinstance(evaluate_every, int) or evaluate_every == 0): | ||||
raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") | raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") | ||||
@@ -3,14 +3,20 @@ __all__ = [ | |||||
'TorchDataLoader', | 'TorchDataLoader', | ||||
'PaddleDataLoader', | 'PaddleDataLoader', | ||||
'JittorDataLoader', | 'JittorDataLoader', | ||||
'OneflowDataLoader', | |||||
'prepare_jittor_dataloader', | 'prepare_jittor_dataloader', | ||||
'prepare_paddle_dataloader', | 'prepare_paddle_dataloader', | ||||
'prepare_torch_dataloader', | 'prepare_torch_dataloader', | ||||
'prepare_oneflow_dataloader', | |||||
"prepare_dataloader" | |||||
"prepare_dataloader", | |||||
"OverfitDataLoader" | |||||
] | ] | ||||
from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader | from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader | ||||
from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader, MixDataLoader | from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader, MixDataLoader | ||||
from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader | from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader | ||||
from .prepare_dataloader import prepare_dataloader | |||||
from .oneflow_dataloader import OneflowDataLoader, prepare_oneflow_dataloader | |||||
from .prepare_dataloader import prepare_dataloader | |||||
from .utils import OverfitDataLoader |
@@ -47,38 +47,35 @@ class JittorDataLoader: | |||||
* callate_fn 为 ``'auto'`` 时,``JittorDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的取值。 | * callate_fn 为 ``'auto'`` 时,``JittorDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的取值。 | ||||
此时可以配套使用 ``JittorDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | 此时可以配套使用 ``JittorDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | ||||
* callate_fn 为 ``None`` 时, ``JittorDataLoader`` 默认使用 Jittor DataLoader 自带的 collate_fn | * callate_fn 为 ``None`` 时, ``JittorDataLoader`` 默认使用 Jittor DataLoader 自带的 collate_fn | ||||
* collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | ||||
:param dataset: 实现了 __getitem__() 和 __len__() 的对象。 | |||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||||
:param shuffle: 是否打乱数据集, 默认为 ``False``。 | |||||
:param drop_last: 当 ``drop_last=True`` 时,``JittorDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||||
:param num_workers: 当 ``num_workers > 0`` 时, ``JittorDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 | |||||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||||
:param buffer_size: 每个进程占用的内存空间,默认为 512M。主要是配合 ``num_workers`` 使用,用户可以自定义每个进程的内存大小。 | |||||
:param stop_grad: 是否不使用梯度, 默认 ``True`` 。 | |||||
:param keep_numpy_array: 返回的数据是 ``np.array`` 类型而不是 ``jittor.Var`` 类型,默认为 ``False`` | |||||
:param endless: 是否让 ``JittorDataLoader`` 无限返回数据,也就是将 dataset 循环使用使得返回数据是没有限制的。默认为 ``False``. | |||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||||
``JittorDataLoader`` 调用默认的 Jittor 框架的 ``DataLoader`` 自带的 ``collate_batch`` 作为 callate_fn 的默认值, 其无法处理 | |||||
:class:`~fastNLP.core.dataset.DataSet` 的 dataset 对象。 | |||||
* callate_fn 为 ``'auto'`` 时,``JittorDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||||
此时可以配套使用 ``JittorDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||||
""" | """ | ||||
def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False, | def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False, | ||||
drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, | drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, | ||||
stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, | stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, | ||||
collate_fn: Union[None, str, Callable] = "auto") -> None: | collate_fn: Union[None, str, Callable] = "auto") -> None: | ||||
""" | |||||
:param dataset: 实现了 __getitem__() 和 __len__() 的对象。 | |||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||||
:param shuffle: 是否打乱数据集, 默认为 ``False``。 | |||||
:param drop_last: 当 ``drop_last=True`` 时,``JittorDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||||
:param num_workers: 当 ``num_workers > 0`` 时, ``JittorDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 | |||||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||||
:param buffer_size: 每个进程占用的内存空间,默认为512M。主要是配合num_workers使用,用户可以自定义每个进程的内存大小。 | |||||
:param stop_grad: 是否不使用梯度, 默认 ``True`` 。 | |||||
:param keep_numpy_array: 返回的数据是 ``np.array`` 类型而不是 ``jittor.Var`` 类型,默认为 ``False`` | |||||
:param endless: 是否让 ``JittorDataLoader`` 无限返回数据,也就是将 dataset 循环使用使得返回数据是没有限制的。默认为 ``False``. | |||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||||
``JittorDataLoader`` 调用默认的 Jittor 框架的 ``DataLoader`` 自带的 ``collate_batch`` 作为 callate_fn 的默认值, 其无法处理 | |||||
:class:`~fastNLP.core.dataset.DataSet` 的 dataset 对象。 | |||||
* callate_fn 为 ``'auto'`` 时,``JittorDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||||
此时可以配套使用 ``JittorDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||||
* collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||||
""" | |||||
# TODO 验证支持replacesampler (以后完成) 增加Sampler | # TODO 验证支持replacesampler (以后完成) 增加Sampler | ||||
# 将内部dataset批次设置为1 | # 将内部dataset批次设置为1 | ||||
if isinstance(dataset, Dataset): | if isinstance(dataset, Dataset): | ||||
@@ -136,20 +133,20 @@ class JittorDataLoader: | |||||
""" | """ | ||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | ||||
:param field_name: 需要调整的 field 的名称。如果 :class:`~fastNLP.core.Dataset` 的 :class:`~fastNLP.core.Dataset.__getitem__` | |||||
方法返回的是 dict 类型的,则可以直接使用对应的 field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 | |||||
``{'a': {'b': 1}}`` 中的使用 ``('a', 'b')`` 如果 ``__getitem__`` 返回的是 Sequence 类型的,则可以使用 *_0*, *_1* 表示序列中 | |||||
第 **0** 或 **1** 个元素。如果该 field 在数据中没有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 ``_single`` 。 | |||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 | |||||
无意义。 | |||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | |||||
:param backend: 可选 ``['raw', 'numpy', 'Jittor', 'paddle', 'jittor', 'auto']`` ,分别代表,输出为 ``list`` , ``numpy.ndarray`` , | |||||
``Jittor.Tensor`` , ``paddle.Tensor`` , ``jittor.Var`` 类型。若 ``pad_val`` 为 ``None`` ,该值无意义 。 | |||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | |||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | |||||
形式,输出将被直接作为结果输出。 | |||||
:return: 返回 Collator 自身 | |||||
:param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||||
如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 ``"_single"`` 。 | |||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``, | |||||
该值无意义。 | |||||
:param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。 | |||||
:param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`, | |||||
:class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。 | |||||
若 ``pad_val`` 为 ``None`` ,该值无意义 。 | |||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的 | |||||
batch 形式。 collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。 | |||||
:return: 返回使用的 collator | |||||
""" | """ | ||||
collator = self._get_collator() | collator = self._get_collator() | ||||
if isinstance(collator, Collator): | if isinstance(collator, Collator): | ||||
@@ -173,16 +170,14 @@ class JittorDataLoader: | |||||
def set_ignore(self, *field_names) -> Collator: | def set_ignore(self, *field_names) -> Collator: | ||||
""" | """ | ||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | |||||
Example:: | |||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略:: | |||||
collator.set_ignore('field1', 'field2') | |||||
dataloader.set_ignore('field1', 'field2') | |||||
:param field_name: 需要调整的 field 的名称。如果 :class:`~fastNLP.core.Dataset` 的 :class:`~fastNLP.core.Dataset.__getitem__` | |||||
方法返回的是 dict 类型的,则可以直接使用对应的 field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 | |||||
``{'a': {'b': 1}}`` 中的使用 ``('a', 'b')`` 如果 ``__getitem__`` 返回的是 Sequence 类型的,则可以使用 *_0*, *_1* 表示序列中 | |||||
第 **0** 或 **1** 个元素。 | |||||
:return: 返回 Collator 自身 | |||||
:param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||||
:return: 返回使用的 collator | |||||
""" | """ | ||||
collator = self._get_collator() | collator = self._get_collator() | ||||
if isinstance(collator, Collator): | if isinstance(collator, Collator): | ||||
@@ -193,14 +188,14 @@ class JittorDataLoader: | |||||
def get_batch_indices(self) -> List[int]: | def get_batch_indices(self) -> List[int]: | ||||
""" | """ | ||||
获取当前 batch 的 idx | |||||
获取当前 ``batch`` 中每条数据对应的索引。 | |||||
:return: | |||||
:return: 当前 ``batch`` 数据的索引; | |||||
""" | """ | ||||
return self.cur_batch_indices | return self.cur_batch_indices | ||||
def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = False, | |||||
def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = None, | |||||
drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, | drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, | ||||
stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, | stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, | ||||
collate_fn: Union[None, str, Callable] = "auto", | collate_fn: Union[None, str, Callable] = "auto", | ||||
@@ -208,36 +203,37 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||||
-> Union[Dict[str, JittorDataLoader], JittorDataLoader]: | -> Union[Dict[str, JittorDataLoader], JittorDataLoader]: | ||||
""" | """ | ||||
``prepare_jittor_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 :class:`JittorDataLoader` 对象, 详见 :class:`~fastNLP.core.dataloaders.JittorDataLoader`。 | ``prepare_jittor_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 :class:`JittorDataLoader` 对象, 详见 :class:`~fastNLP.core.dataloaders.JittorDataLoader`。 | ||||
根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||||
* 当 ds_or_db 为 ``DataSet`` 时,``prepare_jittor_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 | |||||
帮你实例化一个 :class:`JittorDataLoader` 对象并返回该对象。 详见 :class:`~fastNLP.core.dataloaders.JittorDataLoader`。 | |||||
* 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_Jittor_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value | |||||
来创建不同的 :class:`JittorDataLoader` 对象;当 key 中包含'train'字符串时,``prepare_jittor_dataloader`` 默认该 value 为 train 数据集, | |||||
会将 batch_size 和 sampler 作为参数,其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。 | |||||
最终根据 ``key: JittorDataLoader`` 组成 ``Dict[key, JittorDataLoader]`` 的字典返回。 | |||||
根据 ``ds_or_db`` 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||||
* 当 ds_or_db 为 :class:`~fastNLP.io.DataSet` 时,``prepare_jittor_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 | |||||
帮你实例化一个 :class:`JittorDataLoader` 对象并返回该对象。 详见 :class:`~fastNLP.core.dataloaders.JittorDataLoader`; | |||||
* 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_jittor_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value | |||||
来创建不同的 :class:`JittorDataLoader` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_jittor_dataloader`` 默认该 value 为训练数据集, | |||||
会将 ``batch_size`` 和 ``sampler`` 作为参数,其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。 | |||||
最终根据 ``key: JittorDataLoader`` 组成 ``Dict[key, JittorDataLoader]`` 的字典返回; | |||||
* 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_jittor_dataloader`` 会遍历 该 dict 的的 key-value 来创建不同的 | * 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_jittor_dataloader`` 会遍历 该 dict 的的 key-value 来创建不同的 | ||||
:class:`JittorDataLoader` 对象;当 key 中包含'train'字符串时,``prepare_Jittor_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数, | |||||
其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: JittorDataLoader`` 组成 | |||||
``Dict[key, JittorDataLoader]`` 的字典返回。 | |||||
:class:`JittorDataLoader` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_jittor_dataloader`` 默认该 value 为训练数据集,会将 ``batch_size`` 和 | |||||
``sampler`` 作为参数,其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。最终根据 ``key: JittorDataLoader`` 组成 | |||||
``Dict[key, JittorDataLoader]`` 的字典返回; | |||||
:param ds_or_db: 可以有以下三种取值, | |||||
:param ds_or_db: 可以有以下三种取值: | |||||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 | |||||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 | |||||
* ds_or_db 为实现了 __getitem__() 和 __len__() 的对象 ,返回值为:class:`~fastNLP.TorchDataLoader` | |||||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典; | |||||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典; | |||||
* ds_or_db 为实现了 :meth:`__getitem__` 和 :meth:`__len__` 的对象 ,返回值为 :class:`~fastNLP.core.dataloaders.JittorDataLoader`; | |||||
:param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 :class:`Dict` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过改参数 | :param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 :class:`Dict` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过改参数 | ||||
设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。 默认为 ``16``。 | 设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。 默认为 ``16``。 | ||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | ||||
:param shuffle: 是否打乱数据集, 默认为 ``False``。 | |||||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 True , | |||||
其它的为 False 。 | |||||
:param drop_last: 当 ``drop_last=True`` 时,:class:`JittorDataLoader` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | :param drop_last: 当 ``drop_last=True`` 时,:class:`JittorDataLoader` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | ||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | ||||
:param num_workers: 当 ``num_workers > 0`` 时, :class:`JittorDataLoader` 会开启 num_workers 个子进程来处理数据, 可以加快 | :param num_workers: 当 ``num_workers > 0`` 时, :class:`JittorDataLoader` 会开启 num_workers 个子进程来处理数据, 可以加快 | ||||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | ||||
:param buffer_size: 每个进程占用的内存空间,默认为512M。主要是配合num_workers使用,用户可以自定义每个进程的内存大小。 | |||||
:param buffer_size: 每个进程占用的内存空间,默认为512M。主要是配合 ``num_workers`` 使用,用户可以自定义每个进程的内存大小。 | |||||
:param stop_grad: 是否不使用梯度, 默认 ``True`` 。 | :param stop_grad: 是否不使用梯度, 默认 ``True`` 。 | ||||
:param keep_numpy_array: 返回的数据是 ``np.array`` 类型而不是 ``jittor.Var`` 类型,默认为 ``False`` | |||||
:param keep_numpy_array: 返回的数据是 :class:`np.array` 类型而不是 :class:`ittor.Var` 类型,默认为 ``False`` | |||||
:param endless: 是否让 :class:`JittorDataLoader` 无限返回数据,也就是将 dataset 循环使用使得返回数据是没有限制的。默认为 ``False``. | :param endless: 是否让 :class:`JittorDataLoader` 无限返回数据,也就是将 dataset 循环使用使得返回数据是没有限制的。默认为 ``False``. | ||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | ||||
@@ -246,11 +242,8 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||||
:class:`~fastNLP.core.dataset.DataSet` 的 dataset 对象。 | :class:`~fastNLP.core.dataset.DataSet` 的 dataset 对象。 | ||||
* callate_fn 为 ``'auto'`` 时,:class:`JittorDataLoader` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | * callate_fn 为 ``'auto'`` 时,:class:`JittorDataLoader` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | ||||
此时可以配套使用 :class:`JittorDataLoader` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | 此时可以配套使用 :class:`JittorDataLoader` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | ||||
* collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | ||||
:return: 返回数据类型为 :class:`Dict[str, JittorDataLoader]`, :class:`JittorDataLoader` 其中之一,根据输入 | |||||
``ds_or_db`` 变化而变化。 | |||||
""" | """ | ||||
from fastNLP.io.data_bundle import DataBundle | from fastNLP.io.data_bundle import DataBundle | ||||
@@ -258,7 +251,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||||
dl_bundle = {} | dl_bundle = {} | ||||
for name, ds in ds_or_db.iter_datasets(): | for name, ds in ds_or_db.iter_datasets(): | ||||
if 'train' in name: | if 'train' in name: | ||||
dl_bundle[name] = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle, | |||||
dl_bundle[name] = JittorDataLoader(ds, batch_size=batch_size, shuffle=True if shuffle is None else shuffle, | |||||
drop_last=drop_last, num_workers=num_workers, | drop_last=drop_last, num_workers=num_workers, | ||||
buffer_size=buffer_size, | buffer_size=buffer_size, | ||||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, | stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, | ||||
@@ -267,7 +260,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||||
else: | else: | ||||
dl_bundle[name] = JittorDataLoader(ds, | dl_bundle[name] = JittorDataLoader(ds, | ||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | batch_size=non_train_batch_size if non_train_batch_size else batch_size, | ||||
shuffle=shuffle, | |||||
shuffle=False if shuffle is None else shuffle, | |||||
drop_last=drop_last, num_workers=num_workers, | drop_last=drop_last, num_workers=num_workers, | ||||
buffer_size=buffer_size, | buffer_size=buffer_size, | ||||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, | stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, | ||||
@@ -279,14 +272,14 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||||
ds_dict = {} | ds_dict = {} | ||||
for name, ds in ds_or_db.items(): | for name, ds in ds_or_db.items(): | ||||
if 'train' in name: | if 'train' in name: | ||||
dl = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle, | |||||
dl = JittorDataLoader(ds, batch_size=batch_size, shuffle=True if shuffle is None else shuffle, | |||||
drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, | drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, | ||||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | ||||
collate_fn=collate_fn) | collate_fn=collate_fn) | ||||
else: | else: | ||||
dl = JittorDataLoader(ds, | dl = JittorDataLoader(ds, | ||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | batch_size=non_train_batch_size if non_train_batch_size else batch_size, | ||||
shuffle=shuffle, | |||||
shuffle=False if shuffle is None else shuffle, | |||||
drop_last=drop_last, num_workers=num_workers, | drop_last=drop_last, num_workers=num_workers, | ||||
buffer_size=buffer_size, | buffer_size=buffer_size, | ||||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, | stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, | ||||
@@ -296,7 +289,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||||
return ds_dict | return ds_dict | ||||
elif isinstance(ds_or_db, HasLenGetitemType): | elif isinstance(ds_or_db, HasLenGetitemType): | ||||
dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle, | |||||
dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=False if shuffle is None else shuffle, | |||||
drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, | drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, | ||||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | ||||
collate_fn=collate_fn) | collate_fn=collate_fn) | ||||
@@ -0,0 +1,6 @@ | |||||
__all__ = [ | |||||
"OneflowDataLoader", | |||||
"prepare_oneflow_dataloader", | |||||
] | |||||
from .fdl import OneflowDataLoader, prepare_oneflow_dataloader |
@@ -0,0 +1,353 @@ | |||||
__all__ = [ | |||||
'OneflowDataLoader', | |||||
'prepare_oneflow_dataloader' | |||||
] | |||||
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List, Any | |||||
from abc import ABC | |||||
from copy import deepcopy | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.collators import Collator | |||||
from fastNLP.core.dataloaders.utils import indice_collate_wrapper | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler | |||||
from ..utils import _match_param | |||||
from ..utils import HasLenGetitemType | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
from oneflow.utils.data import DataLoader, Sampler, Dataset | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | |||||
class _FDataSet: | |||||
""" | |||||
提供给 ``OneflowDataLoader`` 使用的 warp 类,其功能是对 dataset 进行封装,wrap 修改 dataset 的 __getitem__ 函数,增加返回 | |||||
数据的下标 idx 。 | |||||
..note:: | |||||
需要注意的是传入 ``__init__`` 的 dataset 需要实现 __getattribute__ 方法才能在 _FDataset 实例化对象中调用 dataset 的方法 | |||||
""" | |||||
def __init__(self, dataset) -> None: | |||||
self.dataset = dataset | |||||
def __getitem__(self, item: Union[int, list]) -> Tuple: | |||||
return (item, self.dataset[item]) | |||||
def __getattr__(self, item): | |||||
try: | |||||
return self.dataset.__getattribute__(item) | |||||
except AttributeError as e: | |||||
raise e | |||||
def __len__(self) -> int: | |||||
return len(self.dataset) | |||||
class OneflowDataLoader(DataLoader): | |||||
""" | |||||
提供给 ``oneflow`` 框架使用的 ``DataLoader`` 函数,``OneflowDataLoader`` 提供了 ``Collator`` 来自动检测 dataset 的每个 field 是否可 pad, | |||||
若是可 pad 的 field 则自动 pad 到相同长度,否则只会将相同 field 的数据收集组成一个 batch 返回。 | |||||
具体详见 :class:`~fastNLP.core.collators.Collator`;用户通过 callte_fn 来控制是否使用该功能, collate_fn 只能为 ``['auto', None, Callable]`` | |||||
三种取值。 | |||||
* callate_fn 为 ``'auto'`` 时,``OneflowDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的取值。 | |||||
此时可以配套使用 ``OneflowDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||||
* callate_fn 为 ``None`` 时, ``OneflowDataLoadr`` 默认使用 :class:`oneflow.utils.data.DataLoader` 自带的 collate_fn | |||||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||||
:param dataset: 实现了 __getitem__() 和 __len__() 的对象。 | |||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||||
:param non_train_batch_size: 非训练数据集的 ``OneflowDataLoader`` 批次大小,默认为 ``16`` 且当 ``batch_sampler`` 为 ``None`` 有效。 | |||||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , | |||||
其它的为 False 。 | |||||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||||
默认为 ``None``, 当其不为 ``None`` 时, shuffle 参数无效。 | |||||
:param non_train_sampler: 非训练数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||||
默认为None, 当其不为 None 时, shuffle 参数无效。 | |||||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | |||||
dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``bacth_size``, ``sampler``, ``shuffle`` 参数均失效。 | |||||
:param num_workers: 当 ``num_workers > 0`` 时, ``OneflowDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快 | |||||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||||
``OneflowDataLoader`` 调用默认的 oneflow 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 | |||||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||||
* callate_fn 为 ``'auto'`` 时,``OneflowDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||||
此时可以配套使用 ``OneflowDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||||
:param pin_memory: 如果其为 ``True``, 那么 ``OneflowDataLoader`` 会在返回数据张量之前将其 copy 到 cuda 的 pin memory 中。 | |||||
:param drop_last: 当 ``drop_last=True`` 时,``OneflowDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||||
:param timeout: 子进程的输出队列获取数据的超时值 | |||||
:param worker_init_fn: init 函数,如果不设置为 ``None``,则将会在每个子进程初始化时调用该函数。 | |||||
:param multiprocessing_context: 多进程的上下文环境 | |||||
:param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个 ``base_seed`` | |||||
:param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2`` 意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` . | |||||
:param persistent_workers: 如果其为 ``True``, ``OneflowDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | |||||
""" | |||||
def __init__(self, dataset, batch_size: int = 16, | |||||
shuffle: bool = False, sampler = None, batch_sampler = None, | |||||
num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', | |||||
pin_memory: bool = False, drop_last: bool = False, | |||||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | |||||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | |||||
persistent_workers: bool = False, **kwargs) -> None: | |||||
if isinstance(dataset, DataSet) and collate_fn is None: | |||||
raise ValueError("When use FastNLP DataSet, collate_fn must be not None") | |||||
if not isinstance(dataset, _FDataSet): | |||||
dataset = _FDataSet(dataset) | |||||
if num_workers>0 and multiprocessing_context is None: | |||||
multiprocessing_context = 'fork' # 这里默认使用fork的方式来启动多进程 | |||||
if batch_sampler is not None: | |||||
batch_size = 1 | |||||
shuffle = False | |||||
sampler = None | |||||
elif sampler is None: | |||||
sampler = RandomSampler(dataset, shuffle=shuffle) | |||||
shuffle = False | |||||
if isinstance(collate_fn, str): | |||||
if collate_fn == 'auto': | |||||
if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset | |||||
collate_fn = deepcopy(dataset.dataset.collator) | |||||
collate_fn.set_backend(backend="oneflow") | |||||
else: | |||||
collate_fn = Collator(backend="oneflow") | |||||
else: | |||||
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | |||||
dl_kwargs = _match_param(OneflowDataLoader.__init__, DataLoader.__init__, fn_name=DataLoader.__name__) | |||||
if dl_kwargs is None: | |||||
super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, | |||||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, | |||||
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
multiprocessing_context=multiprocessing_context, generator=generator, | |||||
prefetch_factor=prefetch_factor, | |||||
persistent_workers=persistent_workers) | |||||
else: | |||||
super().__init__(**dl_kwargs) | |||||
self.cur_batch_indices = None | |||||
def __iter__(self): | |||||
self.collate_fn = indice_collate_wrapper(self.collate_fn) | |||||
for indices, data in super().__iter__(): | |||||
self.cur_batch_indices = indices | |||||
yield data | |||||
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, | |||||
pad_fn: Callable = None) -> Collator: | |||||
""" | |||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | |||||
:param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||||
如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 "_single" 。 | |||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``, | |||||
该值无意义。 | |||||
:param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。 | |||||
:param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`, | |||||
:class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。 | |||||
若 ``pad_val`` 为 ``None`` ,该值无意义 。 | |||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的 | |||||
batch 形式。 collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。 | |||||
:return: 返回使用的 collator | |||||
""" | |||||
collator = self._get_collator() | |||||
if isinstance(collator, Collator): | |||||
collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) | |||||
return collator | |||||
else: | |||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") | |||||
def _get_collator(self): | |||||
""" | |||||
如果 collate_fn 是 Collator 对象,得到该对象。如果没有的话,返回 None | |||||
:return: | |||||
""" | |||||
collator = None | |||||
if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator): | |||||
collator = self.collate_fn.__wrapped__ | |||||
elif isinstance(self.collate_fn, Collator): | |||||
collator = self.collate_fn | |||||
return collator | |||||
def set_ignore(self, *field_names) -> Collator: | |||||
""" | |||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略:: | |||||
dataloader.set_ignore('field1', 'field2') | |||||
:param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||||
:return: 返回使用的 collator | |||||
""" | |||||
collator = self._get_collator() | |||||
if isinstance(collator, Collator): | |||||
collator.set_ignore(*field_names) | |||||
return collator | |||||
else: | |||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") | |||||
def get_batch_indices(self) -> List[int]: | |||||
""" | |||||
获取当前 ``batch`` 中每条数据对应的索引。 | |||||
:return: 当前 ``batch`` 数据的索引; | |||||
""" | |||||
return self.cur_batch_indices | |||||
def prepare_oneflow_dataloader(ds_or_db, | |||||
batch_size: int = 16, | |||||
shuffle: bool = None, | |||||
sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | |||||
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | |||||
num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', | |||||
pin_memory: bool = False, drop_last: bool = False, | |||||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | |||||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | |||||
persistent_workers: bool = False, | |||||
non_train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | |||||
non_train_batch_size: int = None) \ | |||||
-> Union[OneflowDataLoader, Dict[str, OneflowDataLoader]]: | |||||
""" | |||||
``prepare_oneflow_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``OneflowDataloader`` 对象, 详见 :class:`~fastNLP.OneflowDataLoader`。 | |||||
根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||||
* 当 ds_or_db 为 ``DataSet`` 时,``prepare_oneflow_dataloader`` 会将使用的除了 ``non_train_batch_size`` 和 ``non_train_sampler`` 以外的参数来 | |||||
帮你实例化一个 ``OneflowDataLoader`` 对象并返回该对象。 详见 :class:`~fastNLP.core.dataloaders.OneflowDataLoader`。 | |||||
* 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_oneflow_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value | |||||
来创建不同的 ``OneflowDataLoader`` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_oneflow_dataloader`` 默认该 value 为训练数据集, | |||||
会将 ``batch_size`` 和 ``sampler`` 作为参数,其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。 | |||||
最终根据 ``key: OneflowDataLoader`` 组成 ``Dict[key, OneflowDataLoader]`` 的字典返回。 | |||||
* 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_oneflow_dataloader`` 会遍历 该 dict 的的 key-value 来创建不同的 | |||||
``OneflowDataLoader`` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_oneflow_dataloader`` 默认该 value 为训练数据集,会将 ``batch_size`` 和 ``sampler`` 作为参数, | |||||
其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。最终根据 ``key: OneflowDataLoader`` 组成 | |||||
``Dict[key, OneflowDataLoader]`` 的字典返回。 | |||||
:param ds_or_db: 可以有以下三种取值, | |||||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, OneflowDataLoader]`` 的字典; | |||||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, OneflowDataLoader]`` 的字典; | |||||
* ds_or_db 为实现了 __getitem__() 和 __len__() 的对象 ,返回值为 :class:`~fastNLP.core.dataloaders.OneflowDataLoader`; | |||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||||
:param non_train_batch_size: 非训练数据集的 ``OneflowDataLoader`` 批次大小,默认为 ``16`` 且当 ``batch_sampler`` 为 ``None`` 有效。 | |||||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , | |||||
其它的为 False 。 | |||||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||||
默认为 ``None``, 当其不为 ``None`` 时, shuffle 参数无效。 | |||||
:param non_train_sampler: 非训练数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||||
默认为None, 当其不为 None 时, shuffle 参数无效。 | |||||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List 中的值为 | |||||
dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``bacth_size``, ``sampler``, ``shuffle`` 参数均失效。 | |||||
:param num_workers: 当 ``num_workers > 0`` 时, ``OneflowDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快 | |||||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||||
``OneflowDataLoader`` 调用默认的 oneflow 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 | |||||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||||
* callate_fn 为 ``'auto'`` 时,``OneflowDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||||
此时可以配套使用 ``OneflowDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||||
:param pin_memory: 如果其为 ``True``, 那么 ``OneflowDataLoader`` 会在返回数据张量之前将其 copy 到 cuda 的 pin memory 中。 | |||||
:param drop_last: 当 ``drop_last=True`` 时,``OneflowDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||||
:param timeout: 子进程的输出队列获取数据的超时值 | |||||
:param worker_init_fn: init 函数,如果不设置为 ``None``,则将会在每个子进程初始化时调用该函数。 | |||||
:param multiprocessing_context: 多进程的上下文环境 | |||||
:param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个 ``base_seed`` | |||||
:param prefetch_factor: 每个 worker 提前装载的 samples 数量。 ``2`` 意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` 。 | |||||
:param persistent_workers: 如果其为 ``True``, ``OneflowDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | |||||
""" | |||||
from fastNLP.io import DataBundle | |||||
if isinstance(ds_or_db, DataBundle): | |||||
dl_bundle = {} | |||||
for name, ds in ds_or_db.iter_datasets(): | |||||
if 'train' in name: | |||||
dl_bundle[name] = OneflowDataLoader(dataset=ds, batch_size=batch_size, | |||||
shuffle=True if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
multiprocessing_context=multiprocessing_context, generator=generator, | |||||
prefetch_factor=prefetch_factor, | |||||
persistent_workers=persistent_workers, | |||||
) | |||||
else: | |||||
dl_bundle[name] = OneflowDataLoader(dataset=ds, | |||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | |||||
shuffle=False if shuffle is None else shuffle, | |||||
sampler=non_train_sampler if non_train_sampler else sampler, | |||||
batch_sampler=batch_sampler, | |||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
multiprocessing_context=multiprocessing_context, generator=generator, | |||||
prefetch_factor=prefetch_factor, | |||||
persistent_workers=persistent_workers, | |||||
) | |||||
return dl_bundle | |||||
elif isinstance(ds_or_db, Mapping): | |||||
dl_bundle = {} | |||||
for name, ds in ds_or_db.items(): | |||||
if 'train' in name: | |||||
dl_bundle[name] = OneflowDataLoader(dataset=ds, batch_size=batch_size, | |||||
shuffle=True if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
multiprocessing_context=multiprocessing_context, generator=generator, | |||||
prefetch_factor=prefetch_factor, | |||||
persistent_workers=persistent_workers, | |||||
) | |||||
else: | |||||
dl_bundle[name] = OneflowDataLoader(dataset=ds, | |||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | |||||
shuffle=False if shuffle is None else shuffle, | |||||
sampler=non_train_sampler if non_train_sampler else sampler, | |||||
batch_sampler=batch_sampler, | |||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
multiprocessing_context=multiprocessing_context, generator=generator, | |||||
prefetch_factor=prefetch_factor, | |||||
persistent_workers=persistent_workers, | |||||
) | |||||
return dl_bundle | |||||
elif isinstance(ds_or_db, HasLenGetitemType): | |||||
dl = OneflowDataLoader(dataset=ds_or_db, batch_size=batch_size, | |||||
shuffle=False if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
multiprocessing_context=multiprocessing_context, generator=generator, | |||||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||||
) | |||||
return dl | |||||
else: | |||||
raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or mapping!") |
@@ -1,6 +1,6 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'PaddleDataLoader', | |||||
'prepare_paddle_dataloader', | 'prepare_paddle_dataloader', | ||||
'PaddleDataLoader' | |||||
] | ] | ||||
from .fdl import PaddleDataLoader, prepare_paddle_dataloader | from .fdl import PaddleDataLoader, prepare_paddle_dataloader |
@@ -39,7 +39,7 @@ class _PaddleDataset(Dataset): | |||||
def __getattr__(self, item): | def __getattr__(self, item): | ||||
try: | try: | ||||
self.dataset.__getattribute__(item) | |||||
return self.dataset.__getattribute__(item) | |||||
except Exception as e: | except Exception as e: | ||||
raise e | raise e | ||||
@@ -53,6 +53,7 @@ class PaddleDataLoader(DataLoader): | |||||
1. ``PaddleDataLoader`` 支持输入的 dataset 是无框架的,只要实现了 __getitem__() 和 __len__() 的对象即可, | 1. ``PaddleDataLoader`` 支持输入的 dataset 是无框架的,只要实现了 __getitem__() 和 __len__() 的对象即可, | ||||
当不使用 :class:`~fastNLP.core.dataset.DataSet` 时也不需要传入 collate_fn, 只要只需要将 ``collate_fn='auto'`` 就能够自动 | 当不使用 :class:`~fastNLP.core.dataset.DataSet` 时也不需要传入 collate_fn, 只要只需要将 ``collate_fn='auto'`` 就能够自动 | ||||
探测数据的类型并判断能否 pad 。此时可以调用 ``set_pad`` 和 ``set_ignore`` 方法来设置 field 的 pad_val 或者忽略某个 field 的 pad 操作。 | 探测数据的类型并判断能否 pad 。此时可以调用 ``set_pad`` 和 ``set_ignore`` 方法来设置 field 的 pad_val 或者忽略某个 field 的 pad 操作。 | ||||
Example:: | Example:: | ||||
from fastNLP import PaddleDataLoader | from fastNLP import PaddleDataLoader | ||||
@@ -76,9 +77,46 @@ class PaddleDataLoader(DataLoader): | |||||
.. note:: | .. note:: | ||||
当传入的dataset为fastNLP的DataSet时,collate_fn不能为None。默认可以是"auto"或者自定义callable函数。 | 当传入的dataset为fastNLP的DataSet时,collate_fn不能为None。默认可以是"auto"或者自定义callable函数。 | ||||
3. 当 collate_fn 为 ``Callable`` 时,该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
3. 当 collate_fn 为 :class:`Callable` 时,该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | ||||
:param dataset: 实现了 __getitem__() 和 __len__() 的对象。 | |||||
:param feed_list: feed Tensor list. | |||||
这个张量能被 ``paddle.static.data`` 创建。 如果 :attr:`return_list` 是 ``False``, 那么 :attr:`feed_list` | |||||
应该被设置。 默认为 ``None `` 。 | |||||
:param places: 将数据放进的一个 list 的 place。 :attr:`places` 能为 None. | |||||
如果 :attr:`places` 为 None, 默认放在 CPUPlace 或者 CUDAPlace(0) 设备上。 如果 ``places`` 是一个 list 类型的 字符串, 那么字符串 | |||||
可以是 ``cpu`` , ``gpu:x`` 或者 ``gpu_pinned`` , 其中 ``x`` 是 gpu 的下标。 | |||||
:param return_list: 每个设备上的返回值是否为以列表形式显示。 如果 :attr:`return_list=False`, | |||||
每个设备上的返回值值为 str -> Tensor 的 dict, 其中 dict 的 key 为每个 fed Tensors 的名字。 | |||||
如果 :attr:`return_list=True`, 每个设备上的返回值值为 list(Tensor)。 :attr:`return_list` 只能在动态图情况下设置为 ``True`` . | |||||
默认值为 ``True`` 。 | |||||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | |||||
dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``bacth_size``, ``shuffle`` 参数均失效。 | |||||
:param batch_size: 批次大小,默认为 ``16`` 且当 ``batch_sampler`` 为 None 有效。 | |||||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , | |||||
其它的为 False 。 | |||||
:param drop_last: 当 ``drop_last=True`` 时,``PaddleDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||||
``PaddleDataLoader`` 调用默认的 Paddle 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 | |||||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||||
* callate_fn 为 ``'auto'`` 时,``PaddleDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||||
此时可以配套使用 ``PaddleDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||||
:param num_workers: 当 ``num_workers > 0`` 时, ``PaddleDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快 | |||||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||||
:param use_buffer_reader: 是否开启 buffer_reader 。如果 ``use_buffer_reader=True`` ,那么 ``PaddleDataLoader`` 会异步地预取下一个 batch 的 | |||||
数据,因此它将会加快数据传输的速度,但是将会占用更多的内存或者显存。默认值是 ``True``。 | |||||
:param use_shared_memory: 是否使用共享内存。当 ``use_shared_memory=True`` 时,将采用共享内存来加快将数据放进进程队列。建议仅当计算机上的 | |||||
共享空间足够大时。(例如 Linux 上的 /dev/shm/ 空间足够大)共享内存仅在多进程模式( ``num_workers>0`` )下生效。 | |||||
:param timeout: 从子进程的输出队列获取数据的超时值 | |||||
:param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 | |||||
:param persistent_workers: 如果其为 ``True``, ``PaddleDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | |||||
""" | """ | ||||
def __init__(self, dataset, feed_list=None, places=None, | def __init__(self, dataset, feed_list=None, places=None, | ||||
@@ -88,45 +126,7 @@ class PaddleDataLoader(DataLoader): | |||||
num_workers: int = 0, use_buffer_reader: bool = True, | num_workers: int = 0, use_buffer_reader: bool = True, | ||||
use_shared_memory: bool = True, timeout: int = 0, | use_shared_memory: bool = True, timeout: int = 0, | ||||
worker_init_fn: Callable = None, persistent_workers=False) -> None: | worker_init_fn: Callable = None, persistent_workers=False) -> None: | ||||
""" | |||||
:param dataset: 实现了 __getitem__() 和 __len__() 的对象。 | |||||
:param feed_list: feed Tensor list。 | |||||
这个张量能被 :code:`paddle.static.data()` 创建。 如果 :attr:`return_list` 是 ``False``, 那么 :attr:`feed_list` | |||||
应该被设置。 默认为 ``None`` | |||||
:param places: 将数据放进的一个 list 的 place。 :attr:`places` 能为 None。 | |||||
如果 :attr:`places` 为 None, 默认放在 CPUPlace 或者 CUDAPlace(0) 设备上。 如果 ``places`` 是一个 list 类型的 字符串, 那么字符串 | |||||
可以是 ``cpu`` , ``gpu:x`` 或者 ``gpu_pinned`` , 其中 ``x`` 是 gpu 的下标。 | |||||
:param return_list: 每个设备上的返回值是否为以列表形式显示。 如果 :attr:`return_list=False`, 每个设备上的返回值值为 str -> Tensor 的 dict, | |||||
其中 dict 的 key 为每个 fed Tensors 的名字。如果 :attr:`return_list` 为 ``True`` , 每个设备上的返回值值为 list(Tensor)。 :attr:`return_list` | |||||
只能在动态图情况下设置为 ``True`` 。默认值为 ``True`` 。 | |||||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | |||||
dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, shuffle 参数均失效。 | |||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||||
:param shuffle: 是否将数据打乱,若``shuffle=True`` 则会将dataset打乱;若否则什么也不做。 | |||||
:param drop_last: 当 ``drop_last=True`` 时,``PaddleDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||||
``PaddleDataLoader`` 调用默认的 Paddle 框架的 ``DataLoader`` 自带的 ``default_collate_fn`` 作为 callate_fn 的默认值, 其无法处理 | |||||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||||
* callate_fn 为 ``'auto'`` 时,``PaddleDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||||
此时可以配套使用 ``PaddleDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||||
* collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||||
:param num_workers: 当 ``num_workers > 0`` 时, ``PaddleDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 | |||||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||||
:param use_buffer_reader: 是否开启 buffer_reader 。如果 ``use_buffer_reader=True`` ,那么 ``PaddleDataLoader`` 会异步地预取下一个 batch 的 | |||||
数据,因此它将会加快数据传输的速度,但是将会占用更多的内存或者显存。默认值是 ``True``。 | |||||
:param use_shared_memory: 是否使用共享内存。当 ``use_shared_memory=True`` 时,将采用共享内存来加快将数据放进进程队列。建议仅当计算机上的 | |||||
共享空间足够大时。(例如 Linux 上的 /dev/shm/ 空间足够大)共享内存仅在多进程模式( num_workers>0 )下生效。 | |||||
:param timeout: 从子进程的输出队列获取数据的超时值 | |||||
:param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 | |||||
:param persistent_workers: 如果其为 ``True``, ``PaddleDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | |||||
""" | |||||
# FastNLP Datset, collate_fn not None | # FastNLP Datset, collate_fn not None | ||||
if isinstance(dataset, FDataSet) and collate_fn is None: | if isinstance(dataset, FDataSet) and collate_fn is None: | ||||
raise ValueError("When use FastNLP DataSet, collate_fn must be not None") | raise ValueError("When use FastNLP DataSet, collate_fn must be not None") | ||||
@@ -137,9 +137,11 @@ class PaddleDataLoader(DataLoader): | |||||
if batch_sampler is None: | if batch_sampler is None: | ||||
batch_sampler = RandomBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle, | batch_sampler = RandomBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle, | ||||
drop_last=drop_last) | drop_last=drop_last) | ||||
batch_size = 1 | |||||
shuffle = False | |||||
drop_last = False | |||||
# 因为无论如何传给 DataLoader 的 batch_sampler 都不是 None | |||||
# 所以要恢复默认值防止报错 | |||||
batch_size = 1 | |||||
shuffle = False | |||||
drop_last = False | |||||
if isinstance(collate_fn, str): | if isinstance(collate_fn, str): | ||||
if collate_fn == 'auto': | if collate_fn == 'auto': | ||||
@@ -184,20 +186,20 @@ class PaddleDataLoader(DataLoader): | |||||
""" | """ | ||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | ||||
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); | |||||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | |||||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | |||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 | |||||
无意义。 | |||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | |||||
:param backend: 可选['raw', 'numpy', 'Paddle', 'paddle', 'paddle', 'auto'],分别代表,输出为 list, numpy.ndarray, | |||||
Paddle.Tensor, paddle.Tensor, paddle.Var 类型。若 pad_val 为 None ,该值无意义 。 | |||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | |||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | |||||
形式,输出将被直接作为结果输出。 | |||||
:return: 返回 Collator 自身 | |||||
:param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||||
如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 "_single" 。 | |||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``, | |||||
该值无意义。 | |||||
:param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。 | |||||
:param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`, | |||||
:class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。 | |||||
若 ``pad_val`` 为 ``None`` ,该值无意义 。 | |||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的 | |||||
batch 形式。 collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。 | |||||
:return: 返回使用的 collator | |||||
""" | """ | ||||
collator = self._get_collator() | collator = self._get_collator() | ||||
if isinstance(collator, Collator): | if isinstance(collator, Collator): | ||||
@@ -221,15 +223,14 @@ class PaddleDataLoader(DataLoader): | |||||
def set_ignore(self, *field_names) -> Collator: | def set_ignore(self, *field_names) -> Collator: | ||||
""" | """ | ||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | |||||
Example:: | |||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略:: | |||||
collator.set_ignore('field1', 'field2') | |||||
dataloader.set_ignore('field1', 'field2') | |||||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 | |||||
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | |||||
:return: 返回 Collator 自身 | |||||
:param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||||
:return: 返回使用的 collator | |||||
""" | """ | ||||
collator = self._get_collator() | collator = self._get_collator() | ||||
if isinstance(collator, Collator): | if isinstance(collator, Collator): | ||||
@@ -258,58 +259,59 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
non_train_batch_size: int = None) \ | non_train_batch_size: int = None) \ | ||||
-> Union[Dict[str, PaddleDataLoader], PaddleDataLoader]: | -> Union[Dict[str, PaddleDataLoader], PaddleDataLoader]: | ||||
""" | """ | ||||
``prepare_paddle_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``PaddleDataloader``对象, 详见 :class:`~fastNLP.PaddleDataLoader`。 | |||||
``prepare_paddle_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``PaddleDataloader`` 对象, 详见 :class:`~fastNLP.PaddleDataLoader`。 | |||||
根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | 根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | ||||
* 当 ds_or_db 为 ``DataSet``时,``prepare_paddle_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 | |||||
帮你实例化一个 ``PaddleDataLoader`` 对象并返回该对象。 详见:class:`~fastNLP.core.dataloaders.PaddleDataLoader`。 | |||||
* 当 ds_or_db 为 ``DataSet`` 时,``prepare_paddle_dataloader`` 会将除了 ``non_train_batch_size`` 和 ``non_train_sampler`` 以外的参数来 | |||||
帮你实例化一个 ``PaddleDataLoader`` 对象并返回该对象。 详见 :class:`~fastNLP.core.dataloaders.PaddleDataLoader`。 | |||||
* 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_paddle_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value | * 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_paddle_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value | ||||
来创建不同的 ``PaddleDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_Paddle_dataloader`` 默认该 value 为 train 数据集, | |||||
会将 batch_size 和 sampler 作为参数,其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。 | |||||
最终根据 ``key: PaddleDataLoader`` 组成 ``Dict[key, PaddleDataLoader]`` 的字典返回。 | |||||
来创建不同的 ``PaddleDataLoader`` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_Paddle_dataloader`` 默认该 value 为训练数据集, | |||||
会将 ``batch_size`` 和 ``sampler`` 作为参数,其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。 | |||||
最终根据 ``key: PaddleDataLoader`` 组成 ``Dict[key, PaddleDataLoader]`` 的字典返回。 | |||||
* 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_paddle_dataloader`` 会遍历 该 dict 的的 key-value 来创建不同的 | * 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_paddle_dataloader`` 会遍历 该 dict 的的 key-value 来创建不同的 | ||||
``PaddleDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_paddle_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数, | |||||
其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: PaddleDataLoader`` 组成 | |||||
``Dict[key, PaddleDataLoader]`` 的字典返回。 | |||||
``PaddleDataLoader`` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_paddle_dataloader`` 默认该 value 为训练数据集,会将 ``batch_size`` 和 ``sampler`` 作为参数, | |||||
其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。最终根据 ``key: PaddleDataLoader`` 组成 | |||||
``Dict[key, PaddleDataLoader]`` 的字典返回。 | |||||
:param ds_or_db: 可以有以下三种取值, | :param ds_or_db: 可以有以下三种取值, | ||||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 | |||||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 | |||||
* ds_or_db 为实现了 __getitem__() 和 __len__() 的对象 ,返回值为:class:`~fastNLP.TorchDataLoader` | |||||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典; | |||||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典; | |||||
* ds_or_db 为实现了 __getitem__() 和 __len__() 的对象 ,返回值为 :class:`~fastNLP.TorchDataLoader`; | |||||
:param feed_list: (list(Tensor)|tuple(Tensor)): feed Tensor list. | |||||
这个张量能被 :code:`paddle.static.data()` 创建。 如果:attr:`return_list` 是 ``False``, 那么 :attr:`feed_list` | |||||
应该被设置。 默认为 ``None `` | |||||
:param places: (list(Place)|tuple(Place)|list(str)|optional): 将数据放进的一个 list 的 place。 :attr:`places` 能为 None. | |||||
如果 :attr:`places` 为 None, 默认放在 CPUPlace 或者 CUDAPlace(0) 设备上。 如果 ``places`` 是一个 list 类型的 字符串, 那么字符串 | |||||
可以是 ``cpu`` , ``gpu:x`` 或者 ``gpu_pinned`` , 其中 ``x`` 是 gpu 的下标。 | |||||
:param feed_list: feed Tensor list. | |||||
这个张量能被 ``paddle.static.data`` 创建。 如果 :attr:`return_list` 是 ``False``, 那么 :attr:`feed_list` | |||||
应该被设置。 默认为 ``None `` 。 | |||||
:param places: 将数据放进的一个 list 的 place。 :attr:`places` 能为 None. | |||||
如果 :attr:`places` 为 None, 默认放在 CPUPlace 或者 CUDAPlace(0) 设备上。 如果 ``places`` 是一个 list 类型的 字符串, 那么字符串 | |||||
可以是 ``cpu`` , ``gpu:x`` 或者 ``gpu_pinned`` , 其中 ``x`` 是 gpu 的下标。 | |||||
:param return_list: 每个设备上的返回值是否为以列表形式显示。 如果 :attr:`return_list=False`, | :param return_list: 每个设备上的返回值是否为以列表形式显示。 如果 :attr:`return_list=False`, | ||||
每个设备上的返回值值为 str -> Tensor 的 dict, 其中 dict 的 key 为每个 fed Tensors 的名字。 | |||||
如果 :attr:`return_list=True`, 每个设备上的返回值值为 list(Tensor)。 :attr:`return_list` 只能在动态图情况下设置为 ``True`` . | |||||
默认值为 ``True`` 。 | |||||
每个设备上的返回值值为 str -> Tensor 的 dict, 其中 dict 的 key 为每个 fed Tensors 的名字。 | |||||
如果 :attr:`return_list=True`, 每个设备上的返回值值为 list(Tensor)。 :attr:`return_list` 只能在动态图情况下设置为 ``True`` . | |||||
默认值为 ``True`` 。 | |||||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | :param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | ||||
dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, shuffle 参数均失效。 | |||||
dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``bacth_size``, ``shuffle`` 参数均失效。 | |||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | ||||
:param shuffle: 是否将数据打乱,若``shuffle=True``则会将dataset打乱;若否则什么也不做。 | |||||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , | |||||
其它的为 False 。 | |||||
:param drop_last: 当 ``drop_last=True`` 时,``PaddleDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | :param drop_last: 当 ``drop_last=True`` 时,``PaddleDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | ||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | ||||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||||
``PaddleDataLoader`` 调用默认的 Paddle 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 | |||||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||||
``PaddleDataLoader`` 调用默认的 Paddle 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 | |||||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||||
* callate_fn 为 ``'auto'`` 时,``PaddleDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | * callate_fn 为 ``'auto'`` 时,``PaddleDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | ||||
此时可以配套使用 ``PaddleDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||||
* `collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||||
:param num_workers: 当 ``num_workers > 0`` 时, ``PaddleDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 | |||||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||||
:param use_buffer_reader: 是否开启 buffer_reader 。如果 `use_buffer_reader=True`` ,那么 ``PaddleDataLoader` `会异步的预取下一个 batch 的 | |||||
数据,因此它将会加快数据传输的速度,但是将会占用更多的内存或者显存。默认值是 ``True``。 | |||||
此时可以配套使用 ``PaddleDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||||
:param num_workers: 当 ``num_workers > 0`` 时, ``PaddleDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快 | |||||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||||
:param use_buffer_reader: 是否开启 buffer_reader 。如果 ``use_buffer_reader=True`` ,那么 ``PaddleDataLoader`` 会异步地预取下一个 batch 的 | |||||
数据,因此它将会加快数据传输的速度,但是将会占用更多的内存或者显存。默认值是 ``True``。 | |||||
:param use_shared_memory: 是否使用共享内存。当 ``use_shared_memory=True`` 时,将采用共享内存来加快将数据放进进程队列。建议仅当计算机上的 | :param use_shared_memory: 是否使用共享内存。当 ``use_shared_memory=True`` 时,将采用共享内存来加快将数据放进进程队列。建议仅当计算机上的 | ||||
共享空间足够大时。(例如 Linux 上的 /dev/shm/ 空间足够大)共享内存仅在多进程模式( num_workers>0 )下生效。 | |||||
共享空间足够大时。(例如 Linux 上的 /dev/shm/ 空间足够大)共享内存仅在多进程模式( ``num_workers>0`` )下生效。 | |||||
:param timeout: 从子进程的输出队列获取数据的超时值 | :param timeout: 从子进程的输出队列获取数据的超时值 | ||||
:param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 | :param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 | ||||
:param persistent_workers: 如果其为 ``True``, ``PaddleDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | :param persistent_workers: 如果其为 ``True``, ``PaddleDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | ||||
@@ -324,7 +326,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places, | dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places, | ||||
return_list=return_list, | return_list=return_list, | ||||
batch_sampler=batch_sampler, batch_size=batch_size, | batch_sampler=batch_sampler, batch_size=batch_size, | ||||
shuffle=shuffle, | |||||
shuffle=True if shuffle is None else shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | ||||
use_shared_memory=use_shared_memory, | use_shared_memory=use_shared_memory, | ||||
use_buffer_reader=use_buffer_reader, | use_buffer_reader=use_buffer_reader, | ||||
@@ -335,7 +337,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
return_list=return_list, | return_list=return_list, | ||||
batch_sampler=batch_sampler, | batch_sampler=batch_sampler, | ||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | batch_size=non_train_batch_size if non_train_batch_size else batch_size, | ||||
shuffle=shuffle, | |||||
shuffle=False if shuffle is None else shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | ||||
use_shared_memory=use_shared_memory, | use_shared_memory=use_shared_memory, | ||||
use_buffer_reader=use_buffer_reader, | use_buffer_reader=use_buffer_reader, | ||||
@@ -348,7 +350,8 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
for name, ds in ds_or_db.items(): | for name, ds in ds_or_db.items(): | ||||
if 'train' in name: | if 'train' in name: | ||||
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | ||||
batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, | |||||
batch_sampler=batch_sampler, batch_size=batch_size, | |||||
shuffle=False if shuffle is None else shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | ||||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | ||||
timeout=timeout, worker_init_fn=worker_init_fn, | timeout=timeout, worker_init_fn=worker_init_fn, | ||||
@@ -357,7 +360,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | ||||
batch_sampler=batch_sampler, | batch_sampler=batch_sampler, | ||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | batch_size=non_train_batch_size if non_train_batch_size else batch_size, | ||||
shuffle=shuffle, | |||||
shuffle=False if shuffle is None else shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | ||||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | ||||
timeout=timeout, worker_init_fn=worker_init_fn, | timeout=timeout, worker_init_fn=worker_init_fn, | ||||
@@ -367,7 +370,8 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
elif isinstance(ds_or_db, HasLenGetitemType): | elif isinstance(ds_or_db, HasLenGetitemType): | ||||
dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list, | dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list, | ||||
batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, | |||||
batch_sampler=batch_sampler, batch_size=batch_size, | |||||
shuffle=False if shuffle is None else shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | ||||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | ||||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | ||||
@@ -9,41 +9,44 @@ import sys | |||||
from .torch_dataloader import prepare_torch_dataloader | from .torch_dataloader import prepare_torch_dataloader | ||||
from .paddle_dataloader import prepare_paddle_dataloader | from .paddle_dataloader import prepare_paddle_dataloader | ||||
from .jittor_dataloader import prepare_jittor_dataloader | from .jittor_dataloader import prepare_jittor_dataloader | ||||
from .oneflow_dataloader import prepare_oneflow_dataloader | |||||
from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS | from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS | ||||
from ..log import logger | from ..log import logger | ||||
def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False, | |||||
def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = None, drop_last: bool = False, | |||||
collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0, | collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0, | ||||
backend: str = 'auto'): | backend: str = 'auto'): | ||||
""" | """ | ||||
自动创建合适的 ``DataLoader`` 对象。例如,检测当当前环境是 ``torch`` 的,则返回 ``TorchDataLoader`` , 是 ``paddle`` 的则 | 自动创建合适的 ``DataLoader`` 对象。例如,检测当当前环境是 ``torch`` 的,则返回 ``TorchDataLoader`` , 是 ``paddle`` 的则 | ||||
返回 ``PaddleDataLoader`` 。如果有更多需要定制的参数,请直接使用对应的 ``prepare`` 函数,例如 | 返回 ``PaddleDataLoader`` 。如果有更多需要定制的参数,请直接使用对应的 ``prepare`` 函数,例如 | ||||
:func:`~fastNLP.prepare_torch_dataloader` 或 :func:`~fastNLP.prepare_paddle_dataloader` 等。 | |||||
:func:`~fastNLP.core.dataloaders.prepare_torch_dataloader` 或 :func:`~fastNLP.core.dataloaders.prepare_paddle_dataloader` 等。 | |||||
:param dataset: 实现 __getitem__() 和 __len__() 的对象;或这种对象的序列;或字典。 | :param dataset: 实现 __getitem__() 和 __len__() 的对象;或这种对象的序列;或字典。 | ||||
* 为单个数据集对象时,返回一个 DataLoader 。 | * 为单个数据集对象时,返回一个 DataLoader 。 | ||||
* 为数据集对象序列时,返回一个序列的 DataLoader 。 | * 为数据集对象序列时,返回一个序列的 DataLoader 。 | ||||
* 为字典型 或 :class:`~fastNLP.io.DataBundle` 数据时,返回 `Dict` 类型的数据。 | |||||
* 为字典型 或 :class:`~fastNLP.io.DataBundle` 数据时,返回 :class:`Dict` 类型的数据。 | |||||
:param batch_size: 批次大小。 | :param batch_size: 批次大小。 | ||||
:param shuffle: 是否打乱数据集。 | |||||
:param drop_last: 当最后一个 batch 不足 batch_size 数量的是否,是否丢弃。 | |||||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , | |||||
其它的为 False 。 | |||||
:param drop_last: 当最后一个 batch 不足 ``batch_size`` 数量的是否,是否丢弃。 | |||||
:param collate_fn: 用于处理一个 batch 的函数,一般包括 padding 和转为 tensor。有以下三种取值: | :param collate_fn: 用于处理一个 batch 的函数,一般包括 padding 和转为 tensor。有以下三种取值: | ||||
* 为 ``auto`` 时,使用 :class:`~fastNLP.Collator` 进行 padding 和 转tensor 。 | * 为 ``auto`` 时,使用 :class:`~fastNLP.Collator` 进行 padding 和 转tensor 。 | ||||
* 为 ``Callable`` 时,应当接受一个 ``batch`` 的数据作为参数,同时输出一个对象 。 | |||||
* 为 :class:`Callable` 时,应当接受一个 ``batch`` 的数据作为参数,同时输出一个对象 。 | |||||
* 为 ``None`` 时,使用各个框架的 DataLoader 的默认 ``collate_fn`` 。 | * 为 ``None`` 时,使用各个框架的 DataLoader 的默认 ``collate_fn`` 。 | ||||
:param num_workers: 使用多少进程进行数据的 fetch 。 | :param num_workers: 使用多少进程进行数据的 fetch 。 | ||||
:param backend: 当前支持 ``["auto", "torch", "paddle", "jittor"]`` 四种类型。 | |||||
:param backend: 当前支持 ``["auto", "torch", "paddle", "jittor", "oneflow"]`` 四种类型。 | |||||
* 为 ``auto`` 时,首先(1) 根据环境变量 "FASTNLP_BACKEND" 进行判断;如果没有设置则,(2)通过当前 | |||||
* 为 ``auto`` 时,首先根据环境变量 ``"FASTNLP_BACKEND"`` 进行判断;如果没有设置则通过当前 | |||||
``sys.modules`` 中已经 import 的 ``backend`` 进行判定。如果以上均无法判定,则报错。如果找到了 | ``sys.modules`` 中已经 import 的 ``backend`` 进行判定。如果以上均无法判定,则报错。如果找到了 | ||||
``backend`` ,则按照下述的方式处理。 | ``backend`` ,则按照下述的方式处理。 | ||||
* 为 ``torch`` 时,使用 :func:`~fastNLP.prepare_torch_dataloader` 。 | |||||
* 为 ``paddle`` 时,使用 :func:`~fastNLP.prepare_paddle_dataloader` 。 | |||||
* 为 ``jittor`` 时,使用 :func:`~fastNLP.prepare_jittor_dataloader` 。 | |||||
* 为 ``torch`` 时,使用 :func:`~fastNLP.core.dataloaders.prepare_torch_dataloader` 。 | |||||
* 为 ``paddle`` 时,使用 :func:`~fastNLP.core.dataloaders.prepare_paddle_dataloader` 。 | |||||
* 为 ``jittor`` 时,使用 :func:`~fastNLP.core.dataloaders.prepare_jittor_dataloader` 。 | |||||
* 为 ``oneflow`` 时,使用 :func:`~fastNLP.core.dataloaders.prepare_oneflow_dataloader` 。 | |||||
:return | :return | ||||
""" | """ | ||||
@@ -60,6 +63,10 @@ def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, dro | |||||
prepare_jittor_dataloader(ds_or_db=dataset, sampler=None, collate_fn=collate_fn, | prepare_jittor_dataloader(ds_or_db=dataset, sampler=None, collate_fn=collate_fn, | ||||
num_workers=num_workers, batch_size=batch_size, shuffle=shuffle, | num_workers=num_workers, batch_size=batch_size, shuffle=shuffle, | ||||
drop_last=drop_last) | drop_last=drop_last) | ||||
elif backend == 'oneflow': | |||||
return prepare_oneflow_dataloader(ds_or_db=dataset, batch_sampler=None, collate_fn=collate_fn, | |||||
num_workers=num_workers, shuffle=shuffle, sampler=None, | |||||
batch_size=batch_size) | |||||
else: | else: | ||||
raise ValueError(f"Currently we do not support backend:{backend}.") | raise ValueError(f"Currently we do not support backend:{backend}.") | ||||
@@ -58,9 +58,41 @@ class TorchDataLoader(DataLoader): | |||||
* callate_fn 为 ``'auto'`` 时,``TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的取值。 | * callate_fn 为 ``'auto'`` 时,``TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的取值。 | ||||
此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | 此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | ||||
* callate_fn 为 ``None`` 时, ``TorchDataLoadr`` 默认使用 torch DataLoader 自带的 collate_fn | * callate_fn 为 ``None`` 时, ``TorchDataLoadr`` 默认使用 torch DataLoader 自带的 collate_fn | ||||
* collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | ||||
:param dataset: 实现了 __getitem__() 和 __len__() 的对象。 | |||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||||
:param non_train_batch_size: 非训练数据集的 ``TorchDataLoader`` 批次大小,默认为 ``16`` 且当 ``batch_sampler`` 为 ``None`` 有效。 | |||||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , | |||||
其它的为 False 。 | |||||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||||
默认为 ``None``, 当其不为 ``None`` 时, shuffle 参数无效。 | |||||
:param non_train_sampler: 非训练数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||||
默认为None, 当其不为 None 时, shuffle 参数无效。 | |||||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | |||||
dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``bacth_size``, ``sampler``, ``shuffle`` 参数均失效。 | |||||
:param num_workers: 当 ``num_workers > 0`` 时, ``TorchDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快 | |||||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||||
``TorchDataLoader`` 调用默认的 torch 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 | |||||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||||
* callate_fn 为 ``'auto'`` 时,``TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||||
此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||||
:param pin_memory: 如果其为 ``True``, 那么 ``TorchDataLoader`` 会在返回数据张量之前将其 copy 到 cud a的 pin memory 中。 | |||||
:param drop_last: 当 ``drop_last=True`` 时,``TorchDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||||
:param timeout: 子进程的输出队列获取数据的超时值 | |||||
:param worker_init_fn: init 函数,如果不设置为 ``None``,则将会在每个子进程初始化时调用该函数。 | |||||
:param multiprocessing_context: 多进程的上下文环境 | |||||
:param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个 ``base_seed`` | |||||
:param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2`` 意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` . | |||||
:param persistent_workers: 如果其为 ``True``, ``TorchDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | |||||
""" | """ | ||||
def __init__(self, dataset, batch_size: int = 16, | def __init__(self, dataset, batch_size: int = 16, | ||||
@@ -70,44 +102,16 @@ class TorchDataLoader(DataLoader): | |||||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | timeout: float = 0, worker_init_fn: Optional[Callable] = None, | ||||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | ||||
persistent_workers: bool = False, **kwargs) -> None: | persistent_workers: bool = False, **kwargs) -> None: | ||||
""" | |||||
:param dataset: 实现了 __getitem__() 和 __len__() 的对象。 | |||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||||
:param shuffle: 是否打乱数据集, 默认为 ``False``。 | |||||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||||
默认为None, 当其不为 None 时, shuffle 参数无效。 | |||||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | |||||
dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, sampler, shuffle 参数均失效。 | |||||
:param num_workers: 当 ``num_workers > 0`` 时, ``TorchDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 | |||||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||||
``TorchDataLoader`` 调用默认的 torch 框架的 ``DataLoader`` 自带的 ``default_collate_fn`` 作为 callate_fn 的默认值, 其无法处理 | |||||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||||
* callate_fn 为 ``'auto'`` 时,``TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||||
此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||||
* collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||||
:param pin_memory: 如果其为 ``True``, 那么 ``TorchDataLoader`` 会在返回数据张量之前将其 copy 到 cud a的 pin memory 中。 | |||||
:param drop_last: 当 ``drop_last=True`` 时,``TorchDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||||
:param timeout: 子进程的输出队列获取数据的超时值 | |||||
:param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 | |||||
:param multiprocessing_context: 多进程的上下文环境 | |||||
:param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个 ``base_seed`` | |||||
:param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2``意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` . | |||||
:param persistent_workers: 如果其为 ``True``, ``TorchDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | |||||
""" | |||||
if isinstance(dataset, DataSet) and collate_fn is None: | if isinstance(dataset, DataSet) and collate_fn is None: | ||||
raise ValueError("When use FastNLP DataSet, collate_fn must be not None") | raise ValueError("When use FastNLP DataSet, collate_fn must be not None") | ||||
if not isinstance(dataset, _FDataSet): | if not isinstance(dataset, _FDataSet): | ||||
dataset = _FDataSet(dataset) | dataset = _FDataSet(dataset) | ||||
if num_workers>0 and multiprocessing_context is None: | |||||
multiprocessing_context = 'fork' # 这里默认使用fork的方式来启动多进程 | |||||
if batch_sampler is not None: | if batch_sampler is not None: | ||||
batch_size = 1 | batch_size = 1 | ||||
shuffle = False | shuffle = False | ||||
@@ -150,20 +154,20 @@ class TorchDataLoader(DataLoader): | |||||
""" | """ | ||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | ||||
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); | |||||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | |||||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | |||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 | |||||
无意义。 | |||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | |||||
:param backend: 可选['raw', 'numpy', 'torch', 'torch', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, | |||||
torch.Tensor, torch.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 | |||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | |||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | |||||
形式,输出将被直接作为结果输出。 | |||||
:return: 返回 Collator | |||||
:param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||||
如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 "_single" 。 | |||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``, | |||||
该值无意义。 | |||||
:param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。 | |||||
:param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`, | |||||
:class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。 | |||||
若 ``pad_val`` 为 ``None`` ,该值无意义 。 | |||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的 | |||||
batch 形式。 collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。 | |||||
:return: 返回使用的 collator | |||||
""" | """ | ||||
collator = self._get_collator() | collator = self._get_collator() | ||||
if isinstance(collator, Collator): | if isinstance(collator, Collator): | ||||
@@ -187,15 +191,14 @@ class TorchDataLoader(DataLoader): | |||||
def set_ignore(self, *field_names) -> Collator: | def set_ignore(self, *field_names) -> Collator: | ||||
""" | """ | ||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | |||||
Example:: | |||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略:: | |||||
collator.set_ignore('field1', 'field2') | |||||
dataloader.set_ignore('field1', 'field2') | |||||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 | |||||
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | |||||
:return: 返回 Collator 自身 | |||||
:param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||||
:return: 返回使用的 collator | |||||
""" | """ | ||||
collator = self._get_collator() | collator = self._get_collator() | ||||
if isinstance(collator, Collator): | if isinstance(collator, Collator): | ||||
@@ -215,7 +218,7 @@ class TorchDataLoader(DataLoader): | |||||
def prepare_torch_dataloader(ds_or_db, | def prepare_torch_dataloader(ds_or_db, | ||||
batch_size: int = 16, | batch_size: int = 16, | ||||
shuffle: bool = False, | |||||
shuffle: bool = None, | |||||
sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | ||||
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | ||||
num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', | num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', | ||||
@@ -227,55 +230,56 @@ def prepare_torch_dataloader(ds_or_db, | |||||
non_train_batch_size: int = None) \ | non_train_batch_size: int = None) \ | ||||
-> Union[TorchDataLoader, Dict[str, TorchDataLoader]]: | -> Union[TorchDataLoader, Dict[str, TorchDataLoader]]: | ||||
""" | """ | ||||
``prepare_torch_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``TorchDataloader``对象, 详见 :class:`~fastNLP.TorchDataLoader`。 | |||||
``prepare_torch_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``TorchDataloader`` 对象, 详见 :class:`~fastNLP.TorchDataLoader`。 | |||||
根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | 根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | ||||
* 当 ds_or_db 为 ``DataSet``时,``prepare_torch_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 | |||||
帮你实例化一个 ``TorchDataLoader`` 对象并返回该对象。 详见:class:`~fastNLP.core.dataloaders.TorchDataLoader`。 | |||||
* 当 ds_or_db 为 ``DataSet`` 时,``prepare_torch_dataloader`` 会将使用的除了 ``non_train_batch_size`` 和 ``non_train_sampler`` 以外的参数来 | |||||
帮你实例化一个 ``TorchDataLoader`` 对象并返回该对象。 详见 :class:`~fastNLP.core.dataloaders.TorchDataLoader`。 | |||||
* 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_torch_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value | * 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_torch_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value | ||||
来创建不同的 ``TorchDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_torch_dataloader`` 默认该 value 为 train 数据集, | |||||
会将 batch_size 和 sampler 作为参数,其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。 | |||||
最终根据 ``key: TorchDataLoader`` 组成 ``Dict[key, TorchDataLoader]`` 的字典返回。 | |||||
来创建不同的 ``TorchDataLoader`` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_torch_dataloader`` 默认该 value 为训练数据集, | |||||
会将 ``batch_size`` 和 ``sampler`` 作为参数,其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。 | |||||
最终根据 ``key: TorchDataLoader`` 组成 ``Dict[key, TorchDataLoader]`` 的字典返回。 | |||||
* 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_torch_dataloader`` 会遍历 该 dict 的的 key-value 来创建不同的 | * 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_torch_dataloader`` 会遍历 该 dict 的的 key-value 来创建不同的 | ||||
``TorchDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_torch_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数, | |||||
其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: TorchDataLoader`` 组成 | |||||
``Dict[key, TorchDataLoader]`` 的字典返回。 | |||||
``TorchDataLoader`` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_torch_dataloader`` 默认该 value 为训练数据集,会将 ``batch_size`` 和 ``sampler`` 作为参数, | |||||
其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。最终根据 ``key: TorchDataLoader`` 组成 | |||||
``Dict[key, TorchDataLoader]`` 的字典返回。 | |||||
:param ds_or_db: 可以有以下三种取值, | :param ds_or_db: 可以有以下三种取值, | ||||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 | |||||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 | |||||
* ds_or_db 为实现了 __getitem__() 和 __len__() 的对象 ,返回值为:class:`~fastNLP.TorchDataLoader` | |||||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典; | |||||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典; | |||||
* ds_or_db 为实现了 __getitem__() 和 __len__() 的对象 ,返回值为 :class:`~fastNLP.TorchDataLoader`; | |||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | ||||
:param non_train_batch_size: 非 'train' 数据集的 ``TorchDataLoader`` 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||||
:param shuffle: 是否打乱数据集, 默认为 ``False``。 | |||||
:param non_train_batch_size: 非训练数据集的 ``TorchDataLoader`` 批次大小,默认为 ``16`` 且当 ``batch_sampler`` 为 ``None`` 有效。 | |||||
:param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` , | |||||
其它的为 False 。 | |||||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | :param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | ||||
默认为None, 当其不为 None 时, shuffle 参数无效。 | |||||
:param non_train_sampler: 非 'train' 数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||||
默认为 ``None``, 当其不为 ``None`` 时, shuffle 参数无效。 | |||||
:param non_train_sampler: 非训练数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||||
默认为None, 当其不为 None 时, shuffle 参数无效。 | 默认为None, 当其不为 None 时, shuffle 参数无效。 | ||||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | :param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | ||||
dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, sampler, shuffle 参数均失效。 | |||||
:param num_workers: 当 ``num_workers > 0`` 时, ``TorchDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 | |||||
dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``bacth_size``, ``sampler``, ``shuffle`` 参数均失效。 | |||||
:param num_workers: 当 ``num_workers > 0`` 时, ``TorchDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快 | |||||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | ||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | ||||
* callate_fn 为 'None' 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||||
``TorchDataLoader`` 调用默认的 torch 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 | |||||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||||
* callate_fn 为 ``'auto'`` 时,`TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||||
此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||||
* `collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||||
``TorchDataLoader`` 调用默认的 torch 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 | |||||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||||
* callate_fn 为 ``'auto'`` 时,``TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||||
此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||||
* collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||||
:param pin_memory: 如果其为 ``True``, 那么 ``TorchDataLoader`` 会在返回数据张量之前将其 copy 到 cud a的 pin memory 中。 | :param pin_memory: 如果其为 ``True``, 那么 ``TorchDataLoader`` 会在返回数据张量之前将其 copy 到 cud a的 pin memory 中。 | ||||
:param drop_last: 当 ``drop_last=True`` 时,``TorchDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | :param drop_last: 当 ``drop_last=True`` 时,``TorchDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | ||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | ||||
:param timeout: 子进程的输出队列获取数据的超时值 | :param timeout: 子进程的输出队列获取数据的超时值 | ||||
:param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 | |||||
:param worker_init_fn: init 函数,如果不设置为 ``None``,则将会在每个子进程初始化时调用该函数。 | |||||
:param multiprocessing_context: 多进程的上下文环境 | :param multiprocessing_context: 多进程的上下文环境 | ||||
:param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个``base_seed`` | |||||
:param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2``意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` . | |||||
:param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个 ``base_seed`` | |||||
:param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2`` 意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` . | |||||
:param persistent_workers: 如果其为 ``True``, ``TorchDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | :param persistent_workers: 如果其为 ``True``, ``TorchDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | ||||
""" | """ | ||||
@@ -287,7 +291,7 @@ def prepare_torch_dataloader(ds_or_db, | |||||
for name, ds in ds_or_db.iter_datasets(): | for name, ds in ds_or_db.iter_datasets(): | ||||
if 'train' in name: | if 'train' in name: | ||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, | dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, | ||||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
shuffle=True if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | ||||
multiprocessing_context=multiprocessing_context, generator=generator, | multiprocessing_context=multiprocessing_context, generator=generator, | ||||
@@ -297,7 +301,7 @@ def prepare_torch_dataloader(ds_or_db, | |||||
else: | else: | ||||
dl_bundle[name] = TorchDataLoader(dataset=ds, | dl_bundle[name] = TorchDataLoader(dataset=ds, | ||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | batch_size=non_train_batch_size if non_train_batch_size else batch_size, | ||||
shuffle=shuffle, | |||||
shuffle=False if shuffle is None else shuffle, | |||||
sampler=non_train_sampler if non_train_sampler else sampler, | sampler=non_train_sampler if non_train_sampler else sampler, | ||||
batch_sampler=batch_sampler, | batch_sampler=batch_sampler, | ||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
@@ -313,7 +317,7 @@ def prepare_torch_dataloader(ds_or_db, | |||||
for name, ds in ds_or_db.items(): | for name, ds in ds_or_db.items(): | ||||
if 'train' in name: | if 'train' in name: | ||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, | dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, | ||||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
shuffle=True if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | ||||
multiprocessing_context=multiprocessing_context, generator=generator, | multiprocessing_context=multiprocessing_context, generator=generator, | ||||
@@ -323,7 +327,7 @@ def prepare_torch_dataloader(ds_or_db, | |||||
else: | else: | ||||
dl_bundle[name] = TorchDataLoader(dataset=ds, | dl_bundle[name] = TorchDataLoader(dataset=ds, | ||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | batch_size=non_train_batch_size if non_train_batch_size else batch_size, | ||||
shuffle=shuffle, | |||||
shuffle=False if shuffle is None else shuffle, | |||||
sampler=non_train_sampler if non_train_sampler else sampler, | sampler=non_train_sampler if non_train_sampler else sampler, | ||||
batch_sampler=batch_sampler, | batch_sampler=batch_sampler, | ||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
@@ -337,7 +341,7 @@ def prepare_torch_dataloader(ds_or_db, | |||||
elif isinstance(ds_or_db, HasLenGetitemType): | elif isinstance(ds_or_db, HasLenGetitemType): | ||||
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, | dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, | ||||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
shuffle=False if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | ||||
multiprocessing_context=multiprocessing_context, generator=generator, | multiprocessing_context=multiprocessing_context, generator=generator, | ||||
@@ -101,6 +101,19 @@ class MixDataLoader(DataLoader): | |||||
""" | """ | ||||
针对以下四种情况提供的 ``MixDataLoader``, 目前只支持 ``torch`` 框架的版本, 其中 mode 的取值范围为 ``['sequential', 'mix', 'polling', "Sampler"]``: | 针对以下四种情况提供的 ``MixDataLoader``, 目前只支持 ``torch`` 框架的版本, 其中 mode 的取值范围为 ``['sequential', 'mix', 'polling', "Sampler"]``: | ||||
* 当 mode 为 ``'sequential'`` 时,``MixDataLoader`` 将 ``datasets`` 的序列或者字典视为一个混合大数据集, 按照 datasets 数据集序列或者字典的顺序一个 | |||||
接一个的 sample 完所有数据。 | |||||
* 当 mode 为 ``'mix'`` 时, ``MixDataLoader`` 将 ``datasets`` 的序列或者字典视为一个混合大数据集, 然后根据用户输入的 idx 序列随机 sample | |||||
混合数据集 datasets 的数据组成一个 batch 序列返回。 | |||||
* 当 mode 为 ``'polling'`` 时, ``MixDataLoader`` 按照 ``datasets`` 数据集的顺序, 先从第一个数据集采样一个 batch 的数据返回, | |||||
再从第二数据集采样一个 batch 数据返回, 直至最后一个数据集采样一个 batch 数据返回后再从第一个数据采样第二个 batch 数据返回,直至所有的数据集都被轮询的采样完。 | |||||
* 当 mode 为 ``"Sampler"`` 时, 该 Sampler 是实现 __iter__() 的实例化对象, 其功能是每次 iter 时返回一个 batch 序列, 其类型为 List[int]; | |||||
且 Sampler 必须将输入的 datasets 视为一个混合大数据集, 其 index 范围为 ``0<idx<len(datasets[0])+...+len(datasets[x])``, 然后参数 | |||||
``sampler``, ``drop_last``, ``ds_ratio`` 均无效。 | |||||
:param datasets: 实现了 __getitem__() 和 __len__() 对象的序列或者字典。 | |||||
:param mode: mode 控制 ``MixDataLoader`` 运行模式。 mode 的取值范围为 ``['sequential', 'mix', 'polling', "Sampler"]``: | |||||
* 当 mode 为 ``'sequential'`` 时,``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 按照 datasets 数据集序列或者字典的顺序一个 | * 当 mode 为 ``'sequential'`` 时,``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 按照 datasets 数据集序列或者字典的顺序一个 | ||||
接一个的 sample 完所有数据。 | 接一个的 sample 完所有数据。 | ||||
* 当 mode 为 ``'mix'`` 时, ``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 然后根据用户输入的 idx 序列随机sample | * 当 mode 为 ``'mix'`` 时, ``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 然后根据用户输入的 idx 序列随机sample | ||||
@@ -111,6 +124,40 @@ class MixDataLoader(DataLoader): | |||||
且 Sampler 必须将输入的 datasets 视为一个混合大数据集, 其 index 范围为 ``0<idx<len(datasets[0])+...+len(datasets[x])``, 然后参数 | 且 Sampler 必须将输入的 datasets 视为一个混合大数据集, 其 index 范围为 ``0<idx<len(datasets[0])+...+len(datasets[x])``, 然后参数 | ||||
sampler, drop_last, ds_ratio 均无效。 | sampler, drop_last, ds_ratio 均无效。 | ||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数。 其取值可以为 ``['auto', Callable, List[Callable], Dict[str, Callable]]``: | |||||
* collate_fn 为 ``'auto'`` 时, ``MixDataLoader`` datasets 序列或者dict 初始化一个 :class:`~fastNLP.core.collators.Collator` 作为其默认值, | |||||
需要注意的是只有当 datasets 包含的所以 dataset 的数据都为 ``List`` 或者 ``Dict`` 类型时才能使用。否则只能用户自己定义 collate_fn . | |||||
* collate_fn 为 :class:`Callable` 时, 该 collate_fn 会被 datasets 序列或者dict 的所有数据所共享。该 Callable 函数应当接受一个 batch 参数作为输入, | |||||
batch 是一个 List 对象且 List 中的每一条数据都是 dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||||
* collate_fn 为 ``Dict[str, Callable]`` 时, datasets 的 key 必须和 callable_fn 的 key 一致。 ``MixDataLoader`` 会将 ``collate_fn[key]`` | |||||
用到 ``datasets[key]`` 的数据集上。 ``collate_fn[key]`` 是一个 Callable 对象。 | |||||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index ,其取值范围为 | |||||
``[None, str, Dict[str, "Sampler"]]``: | |||||
* sampler 为 ``None`` 时, ``MixDataLoader`` 默认初始化 ``torch`` 的 ``SequentialSampler`` 作为默认值。其功能时顺序返回 dataset 的下标。 | |||||
* sampler 为 ``str`` 时, sampler 选择范围为 ``[rand, seq]``。当 sampler 为 ``rand`` 时,``MixDataLoader`` 默认初始化 ``torch`` 的 ``RandomSampler`` | |||||
作为默认值, 其功能时随机采样 dataset 的下标并返回。 当 sampler 为 ``seq`` 时, ``MixDataLoader`` 默认初始化 ``torch`` 的 ``SequentialSampler`` 作为默认值。其功能时顺序返回 dataset 的下标。 | |||||
* sampler 为 ``Dict[str, "Sampler"]`` 时, ``Sampler`` 为用户定义的实现了 __len__() 和 __iter__() 的实例化对象。 其每次 iter 必须返回一个 int 下标。 | |||||
Dict 的 str 必须和 datasets 的 key 一致。 也即是 ``Dict[str, Sampler]`` 为 datasets 字典的每个 dataset 初始化了一个 Sampler。 | |||||
:param num_workers: 当 ``num_workers > 0`` 时, ``MixDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快数据处理速度,但同时 | |||||
也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 ``None`` 有效。 且 datasets 上所有 dataset 的 batch_size 一致。 | |||||
:param drop_last: 当 ``drop_last=True`` 时,``MixDataLoader`` 会扔掉 datasets 中 每个 dataset 最后一个长度小于 ``batch_size`` 的 batch 数据; | |||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||||
:param ds_ratio: ``ds_ratio`` 是控制 datasets 怎么组成一个混合大数据集的重要参数, 其取值为 ``[None, 'truncate_to_least', 'pad_to_most', List[float], Dict[str, float]]``: | |||||
* ds_ratio 为 ``None``, datasets 数据集序列或字典不进行数据扩充处理。 | |||||
* ds_ratio 为 ``'truncate_to_least'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最断长度 ``mix_len``, 其他数据集会被切断 | |||||
到最短长度 ``mix_len``。这种切断不是物理上切断,``MixDataLoader`` 会根据 sampler 不同来采样数据集到指定的最短长度 ``mix_len``。 | |||||
* ds_ratio 为 ``'pad_to_most'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最大长度 ``max_len``, 其他其他数据集会扩充 | |||||
到最大长度 ``mix_len``。这种扩充不是物理上扩充, ``MixDataLoader`` 会根据 sampler 不同来重采样 dataset 到指定的最大长度 ``max_len``。 | |||||
* ds_ratio 为 ``Dict[str, float]`` 时, datasets 类型也必须为 ``Dict[str, DataSet]``, 其 key 一一对应。 ds_ratio 的 value 是任意大于 0 的浮点数, | |||||
代表着 datasets 的 value 数据进行扩充或者缩减的倍数。 | |||||
""" | """ | ||||
def __init__(self, datasets: Dict = None, mode: str = 'sequential', | def __init__(self, datasets: Dict = None, mode: str = 'sequential', | ||||
@@ -119,55 +166,6 @@ class MixDataLoader(DataLoader): | |||||
num_workers: int = 0, batch_size: int = 16, drop_last=False, | num_workers: int = 0, batch_size: int = 16, drop_last=False, | ||||
ds_ratio: Union[None, str, Dict[str, float]] = None, | ds_ratio: Union[None, str, Dict[str, float]] = None, | ||||
pin_memory: bool = False) -> None: | pin_memory: bool = False) -> None: | ||||
""" | |||||
:param datasets: 实现了 __getitem__() 和 __len__() 对象的序列或者字典。 | |||||
:param mode: mode 控制 ``MixDataLoader`` 运行模式。 mode 的取值范围为 ``['sequential', 'mix', 'polling', "Sampler"]``: | |||||
* 当 mode 为 ``'sequential'`` 时,``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 按照 datasets 数据集序列或者字典的顺序一个 | |||||
接一个的 sample 完所有数据。 | |||||
* 当 mode 为 ``'mix'`` 时, ``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 然后根据用户输入的 idx 序列随机sample | |||||
混合数据集 datasets 的数据组成一个 batch 序列返回。 | |||||
* 当 mode 为 ``'polling'`` 时, ``MixDataLoader`` 按照 datasets 数据集的顺序, 先从第一个数据集采样一个 batch 的数据返回, | |||||
再从第二数据集采样一个 batch 数据返回, 直至最后一个数据集采样一个 batch 数据返回后再从第一个数据采样第二个 batch 数据返回,直至所有的数据集都被轮询的采样完。 | |||||
* 当 mode 为 ``"Sampler"`` 时, 该 Sampler 是实现 __iter__() 的实例化对象, 其功能是每次 iter 时返回一个 batch 序列, 其类型为 List[int]; | |||||
且 Sampler 必须将输入的 datasets 视为一个混合大数据集, 其 index 范围为 ``0<idx<len(datasets[0])+...+len(datasets[x])``, 然后参数 | |||||
sampler, drop_last, ds_ratio 均无效。 | |||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数。 其取值可以为 ``['auto', Callable, List[Callable], Dict[str, Callable]]``: | |||||
* collate_fn 为 ``'auto'`` 时, ``MixDataLoader`` datasets 序列或者dict 初始化一个 :class:`~fastNLP.core.collators.Collator` 作为其默认值, | |||||
需要注意的是只有当 datasets 包含的所以 dataset 的数据都为 ``List`` 或者 ``Dict`` 类型时才能使用。否则只能用户自己定义 collate_fn . | |||||
* collate_fn 为 ``Callable`` 时, 该 collate_fn 会被 datasets 序列或者dict 的所有数据所共享。该 Callable 函数应当接受一个 batch 参数作为输入, | |||||
batch 是一个 List 对象且 List 中的每一条数据都是 dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||||
* collate_fn 为 ``Dict[str, Callable]`` 时, datasets 的 key 必须和 callable_fn 的 key 一致。 ``MixDataLoader`` 会将 ``collate_fn[key]`` | |||||
用到 ``datasets[key]`` 的数据集上。 ``collate_fn[key]`` 是一个 Callable 对象。 | |||||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index ,其取值范围为 | |||||
``[None, str, Dict[str, "Sampler"]]``: | |||||
* sampler 为 ``None`` 时, ``MixDataLoader`` 默认初始化 ``torch`` 的 ``SequentialSampler`` 作为默认值。其功能时顺序返回 dataset 的下标。 | |||||
* sampler 为 ``str`` 时, sampler 选择范围为 ``[rand, seq]``。当 sampler 为 ``rand`` 时,``MixDataLoader`` 默认初始化 ``torch`` 的 ``RandomSampler`` | |||||
作为默认值, 其功能时随机采样 dataset 的下标并返回。 当 sampler 为 ``seq`` 时, ``MixDataLoader`` 默认初始化 ``torch`` 的 ``SequentialSampler`` 作为默认值。其功能时顺序返回 dataset 的下标。 | |||||
* sampler 为 ``Dict[str, "Sampler"]`` 时, ``Sampler`` 为用户定义的实现了 __len__() 和 __iter__() 的实例化对象。 其每次 iter 必须返回一个 int 下标。 | |||||
Dict 的 str 必须和 datasets 的 key 一致。 也即是 ``Dict[str, Sampler]`` 为 datasets 字典的每个 dataset 初始化勒一个 Sampler。 | |||||
:param num_workers: 当 ``num_workers > 0`` 时, ``MixDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快数据处理速度,但同时 | |||||
也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 且 datasets 上所有 dataset 的 batch_size 一致。 | |||||
:param drop_last: 当 ``drop_last=True`` 时,``MixDataLoader`` 会扔掉 datasets 中 每个 dataset 最后一个长度小于 ``batch_size`` 的 batch 数据; | |||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||||
:param ds_ratio: ``ds_ratio`` 是控制 datasets 怎么组成一个混合大数据集的重要参数, 其取值为 ``[None, 'truncate_to_least', 'pad_to_most', List[float], Dict[str, float]]``: | |||||
* ds_ratio 为 ``None``, datasets 数据集序列或字典不进行数据扩充处理。 | |||||
* ds_ratio 为 ``'truncate_to_least'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最断长度 ``mix_len``, 其他数据集会被切断 | |||||
到最短长度 ``mix_len``。这种切断不是物理上切断,``MixDataLoader`` 会根据 sampler 不同来采样数据集到指定的最短长度 ``mix_len``。 | |||||
* ds_ratio 为 ``'pad_to_most'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最大长度 ``max_len``, 其他其他数据集会扩充 | |||||
到最大长度 ``mix_len``。这种扩充不是物理上扩充, ``MixDataLoader`` 会根据 sampler 不同来重采样 dataset 到指定的最大长度``max_len``。 | |||||
* ds_ratio 为 ``Dict[str, float]`` 时, datasets 类型也必须为 ``Dict[str, DataSet]``, 其 key 一一对应。 ds_ratio 的 value 是任意大于 0 的浮点数, | |||||
代表着 datasets 的 value 数据进行扩充或者缩减的倍数。 | |||||
""" | |||||
# sampler 为 dict,则判断是否与 datasets 的 key 相同 | # sampler 为 dict,则判断是否与 datasets 的 key 相同 | ||||
if isinstance(sampler, Dict): | if isinstance(sampler, Dict): | ||||
for key in datasets.keys(): | for key in datasets.keys(): | ||||
@@ -1,4 +1,5 @@ | |||||
from typing import Callable, Any, Union | |||||
import os | |||||
from typing import Callable, Any, Union, Sequence | |||||
from abc import ABC | from abc import ABC | ||||
import inspect | import inspect | ||||
import ast | import ast | ||||
@@ -6,13 +7,14 @@ import ast | |||||
from ..log import logger | from ..log import logger | ||||
from ..utils.cache_results import get_func_calls, truncate_start_blanks | from ..utils.cache_results import get_func_calls, truncate_start_blanks | ||||
__all__ = [ | __all__ = [ | ||||
"indice_collate_wrapper" | |||||
"indice_collate_wrapper", | |||||
"OverfitDataLoader" | |||||
] | ] | ||||
def indice_collate_wrapper(func:Callable): | def indice_collate_wrapper(func:Callable): | ||||
""" | """ | ||||
其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。 | |||||
其功能是封装一层 collate_fn,将 dataset 取到的 tuple 数据分离开,将 idx 打包为 indices。 | |||||
:param func: 需要修饰的函数 | :param func: 需要修饰的函数 | ||||
:return: | :return: | ||||
@@ -111,6 +113,40 @@ class HasLenGetitemType(ABC): | |||||
return NotImplemented | return NotImplemented | ||||
class OverfitDataLoader: | |||||
""" | |||||
实现一个简单的迭代器来模拟实际的 dataloader,从给定的 ``dataloader`` 中取出部分数据,来让 Trainer 实现 overfit 的功能; | |||||
""" | |||||
def __init__(self, dataloader, overfit_batches: int, batches=None): | |||||
# batches 参数是给重新初始化dataloader使用的 | |||||
self.dataloader = dataloader # 需要将实际的 dataloader 挂载到该对象上,从而应付一些对于实际的 dataloader 的操作; | |||||
if batches is None: | |||||
self.batches = [] | |||||
self.overfit_batches = int(overfit_batches) | |||||
if self.overfit_batches > len(dataloader): | |||||
logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.") | |||||
for idx, batch in enumerate(dataloader): | |||||
if idx < self.overfit_batches or self.overfit_batches <= -1: | |||||
self.batches.append(batch) | |||||
else: | |||||
assert isinstance(batches, list) | |||||
self.batches = batches | |||||
def __len__(self): | |||||
return len(self.batches) | |||||
def __iter__(self): | |||||
for batch in self.batches: | |||||
yield batch | |||||
def __getattr__(self, item): | |||||
return getattr(self.dataloader, item) | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
def demo(*args, **kwargs): | def demo(*args, **kwargs): | ||||
pass | pass | ||||
@@ -1,7 +1,7 @@ | |||||
r""" | r""" | ||||
:class:`~fastNLP.core.dataset.DataSet` 是 fastNLP 中用于承载数据的容器。可以将 DataSet 看做是一个表格, | :class:`~fastNLP.core.dataset.DataSet` 是 fastNLP 中用于承载数据的容器。可以将 DataSet 看做是一个表格, | ||||
每一行是一个 sample (在 fastNLP 中被称为 :mod:`~fastNLP.core.instance` ), | |||||
每一列是一个 feature (在 fastNLP 中称为 :mod:`~fastNLP.core.field` )。 | |||||
每一行是一个 sample (在 fastNLP 中被称为 :mod:`~fastNLP.core.dataset.instance` ), | |||||
每一列是一个 feature (在 fastNLP 中称为 :mod:`~fastNLP.core.dataset.field` )。 | |||||
.. csv-table:: Following is a demo layout of DataSet | .. csv-table:: Following is a demo layout of DataSet | ||||
:header: "sentence", "words", "seq_len" | :header: "sentence", "words", "seq_len" | ||||
@@ -11,7 +11,7 @@ r""" | |||||
"Third instance .", "[Third, instance, .]", 3 | "Third instance .", "[Third, instance, .]", 3 | ||||
"...", "[...]", "..." | "...", "[...]", "..." | ||||
在 fastNLP 内部每一行是一个 :class:`~fastNLP.Instance` 对象; 每一列是一个 :class:`~fastNLP.FieldArray` 对象。 | |||||
在 fastNLP 内部每一行是一个 :class:`~fastNLP.core.dataset.Instance` 对象; 每一列是一个 :class:`~fastNLP.core.dataset.FieldArray` 对象。 | |||||
---------------------------- | ---------------------------- | ||||
1.DataSet的创建 | 1.DataSet的创建 | ||||
@@ -65,7 +65,7 @@ r""" | |||||
2.DataSet 与预处理 | 2.DataSet 与预处理 | ||||
-------------------------------------- | -------------------------------------- | ||||
常见的预处理有如下几种 | |||||
常见的预处理有如下几种: | |||||
2.1 从某个文本文件读取内容 | 2.1 从某个文本文件读取内容 | ||||
-------------------------------------- | -------------------------------------- | ||||
@@ -97,10 +97,10 @@ r""" | |||||
# 将句子分成单词形式, 详见DataSet.apply()方法, 可以开启多进程来加快处理, 也可以更改展示的bar,目前支持 ``['rich', 'tqdm', None]``, | # 将句子分成单词形式, 详见DataSet.apply()方法, 可以开启多进程来加快处理, 也可以更改展示的bar,目前支持 ``['rich', 'tqdm', None]``, | ||||
# 详细内容可以见 :class:`~fastNLP.core.dataset.DataSet`, 需要注意的时匿名函数不支持多进程 | # 详细内容可以见 :class:`~fastNLP.core.dataset.DataSet`, 需要注意的时匿名函数不支持多进程 | ||||
dataset.apply(lambda ins: ins['sentence'].split(), new_field_name='words', | dataset.apply(lambda ins: ins['sentence'].split(), new_field_name='words', | ||||
progress_des='Main',progress_bar='rich') | |||||
progress_des='Main',progress_bar='rich') | |||||
# 或使用DataSet.apply_field() | # 或使用DataSet.apply_field() | ||||
dataset.apply_field(lambda sent:sent.split(), field_name='sentence', new_field_name='words', | dataset.apply_field(lambda sent:sent.split(), field_name='sentence', new_field_name='words', | ||||
progress_des='Main',progress_bar='rich') | |||||
progress_des='Main',progress_bar='rich') | |||||
# 除了匿名函数,也可以定义函数传递进去 | # 除了匿名函数,也可以定义函数传递进去 | ||||
def get_words(instance): | def get_words(instance): | ||||
sentence = instance['sentence'] | sentence = instance['sentence'] | ||||
@@ -145,8 +145,8 @@ r""" | |||||
# DataSet 的长度 | # DataSet 的长度 | ||||
len(dataset) | len(dataset) | ||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
"DataSet", | "DataSet", | ||||
"ApplyResultException" | "ApplyResultException" | ||||
@@ -255,34 +255,31 @@ def _multi_proc(ds, _apply_field, func, counter, queue): | |||||
class DataSet: | class DataSet: | ||||
r""" | r""" | ||||
fastNLP的数据容器,详细的使用方法见文档 :mod:`fastNLP.core.dataset` | |||||
""" | |||||
fastNLP的数据容器。 | |||||
def __init__(self, data: Union[List[Instance], Dict[str, List[Any]], None] = None): | |||||
r""" | |||||
初始化 ``DataSet``, fastNLP的 DataSet 是 key-value 存储形式, 目前支持两种初始化方式,输入 data 分别为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 和 | |||||
``Dict[str, List[Any]]``。 | |||||
* 当 data 为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 时, 每个 ``Instance`` 的 field_name 需要保持一致。 | |||||
Instance 详见 :class:`~fastNLP.core.dataset.Instance` 。 | |||||
* 当 data 为 ``Dict[str, List[Any]] 时, 则每个 key 的 value 应该为等长的 list, 否则不同 field 的长度不一致。 | |||||
Example:: | |||||
:param data: 初始化的内容, 其只能为两种类型,分别为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 和 | |||||
``Dict[str, List[Any]]``。 | |||||
from fastNLP.core.dataset import DataSet, Instance | |||||
data = {'x': [[1, 0, 1], [0, 1, 1], 'y': [0, 1]} | |||||
data1 = [Instance(x=[1,0,1],y=0), Instance(x=[0,1,1],y=1)] | |||||
ds = DataSet(data) | |||||
ds = DataSet(data1) | |||||
* 当 data 为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 时, 每个 ``Instance`` 的 field_name 需要保持一致。 | |||||
Instance 详见 :class:`~fastNLP.core.dataset.Instance` 。 | |||||
* 当 data 为 ``Dict[str, List[Any]] 时, 则每个 key 的 value 应该为等长的 list, 否则不同 field 的长度不一致。 | |||||
fastNLP的 DataSet 是 key-value 存储形式, 目前支持两种初始化方式,输入 data 分别为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 和 | |||||
``Dict[str, List[Any]]``。 | |||||
Example:: | |||||
* 当 data 为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 时, 每个 ``Instance`` 的 field_name 需要保持一致。 | |||||
Instance 详见 :class:`~fastNLP.core.dataset.Instance` 。 | |||||
* 当 data 为 ``Dict[str, List[Any]]`` 时, 则每个 key 的 value 应该为等长的 list, 否则不同 field 的长度不一致。 | |||||
from fastNLP.core.dataset import DataSet, Instance | |||||
data = {'x': [[1, 0, 1], [0, 1, 1], 'y': [0, 1]} | |||||
data1 = [Instance(x=[1,0,1],y=0), Instance(x=[0,1,1],y=1)] | |||||
ds = DataSet(data) | |||||
ds = DataSet(data1) | |||||
:param data: 初始化的内容,其只能为两种类型,分别为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 和 | |||||
``Dict[str, List[Any]]``。 | |||||
""" | |||||
* 当 data 为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 时, 每个 ``Instance`` 的 field_name 需要保持一致。 | |||||
Instance 详见 :class:`~fastNLP.core.dataset.Instance` 。 | |||||
* 当 data 为 ``Dict[str, List[Any]] 时, 则每个 key 的 value 应该为等长的 list, 否则不同 field 的长度不一致。 | |||||
""" | |||||
def __init__(self, data: Union[List[Instance], Dict[str, List[Any]], None] = None): | |||||
self.field_arrays = {} | self.field_arrays = {} | ||||
self._collator = Collator() | self._collator = Collator() | ||||
if data is not None: | if data is not None: | ||||
@@ -429,10 +426,9 @@ class DataSet: | |||||
def append(self, instance: Instance) -> None: | def append(self, instance: Instance) -> None: | ||||
r""" | r""" | ||||
将一个 instance 对象 append 到 DataSet 后面。详见 :class:`~fastNLP.Instance` | |||||
:param instance: 若 DataSet 不为空,则 instance 应该拥有和 DataSet 完全一样的 field。 | |||||
将一个 ``instance`` 对象 append 到 DataSet 后面。详见 :class:`~fastNLP.core.dataset.Instance` | |||||
:param instance: 若 DataSet 不为空,则 instance 应该拥有和 DataSet 完全一样的 field; | |||||
""" | """ | ||||
if len(self.field_arrays) == 0: | if len(self.field_arrays) == 0: | ||||
# DataSet has no field yet | # DataSet has no field yet | ||||
@@ -445,7 +441,7 @@ class DataSet: | |||||
"DataSet object has {} fields, but attempt to append an Instance object with {} fields." | "DataSet object has {} fields, but attempt to append an Instance object with {} fields." | ||||
.format(len(self.field_arrays), len(instance.fields))) | .format(len(self.field_arrays), len(instance.fields))) | ||||
for name, field in instance.items(): | for name, field in instance.items(): | ||||
assert name in self.field_arrays | |||||
assert name in self.field_arrays, f'Field:`{name}` is not found in {self.field_arrays.keys()}' | |||||
try: | try: | ||||
self.field_arrays[name].append(field) | self.field_arrays[name].append(field) | ||||
except Exception as e: | except Exception as e: | ||||
@@ -454,10 +450,10 @@ class DataSet: | |||||
def add_fieldarray(self, field_name: str, fieldarray: FieldArray) -> None: | def add_fieldarray(self, field_name: str, fieldarray: FieldArray) -> None: | ||||
r""" | r""" | ||||
将 fieldarray 添加到 DataSet 中. | |||||
将 ``fieldarray`` 添加到 DataSet 中. | |||||
:param field_name: 新加入的 field 的名称 | |||||
:param fieldarray: 需要加入 DataSet 的 field 的内容, 详见 :class:`~fastNLP.core.dataset.FieldArray` | |||||
:param field_name: 新加入的 field 的名称; | |||||
:param fieldarray: 需要加入 DataSet 的 field 的内容, 详见 :class:`~fastNLP.core.dataset.FieldArray` ; | |||||
:return: | :return: | ||||
""" | """ | ||||
if not isinstance(fieldarray, FieldArray): | if not isinstance(fieldarray, FieldArray): | ||||
@@ -472,8 +468,8 @@ class DataSet: | |||||
r""" | r""" | ||||
新增一个 field, 需要注意的是 fields 的长度跟 DataSet 长度一致 | 新增一个 field, 需要注意的是 fields 的长度跟 DataSet 长度一致 | ||||
:param field_name: 新增的 field 的名称 | |||||
:param fields: 需要新增的 field 的内容 | |||||
:param field_name: 新增的 field 的名称; | |||||
:param fields: 需要新增的 field 的内容; | |||||
""" | """ | ||||
if len(self.field_arrays) != 0: | if len(self.field_arrays) != 0: | ||||
@@ -484,9 +480,9 @@ class DataSet: | |||||
def delete_instance(self, index: int): | def delete_instance(self, index: int): | ||||
r""" | r""" | ||||
删除第 ``index `` 个 Instance | |||||
删除第 ``index`` 个 Instance | |||||
:param index: 需要删除的 instanc e的 index,序号从 `0` 开始。 | |||||
:param index: 需要删除的 instance 的 index,序号从 `0` 开始。 | |||||
""" | """ | ||||
assert isinstance(index, int), "Only integer supported." | assert isinstance(index, int), "Only integer supported." | ||||
if len(self) <= index: | if len(self) <= index: | ||||
@@ -500,9 +496,9 @@ class DataSet: | |||||
def delete_field(self, field_name: str): | def delete_field(self, field_name: str): | ||||
r""" | r""" | ||||
删除名为 field_name 的 field | |||||
删除名为 ``field_name`` 的 field | |||||
:param field_name: 需要删除的 field 的名称. | |||||
:param field_name: 需要删除的 field 的名称; | |||||
""" | """ | ||||
if self.has_field(field_name): | if self.has_field(field_name): | ||||
self.field_arrays.pop(field_name) | self.field_arrays.pop(field_name) | ||||
@@ -512,11 +508,11 @@ class DataSet: | |||||
def copy_field(self, field_name: str, new_field_name: str): | def copy_field(self, field_name: str, new_field_name: str): | ||||
r""" | r""" | ||||
深度 copy 名为 field_name 的 field 到 new_field_name | |||||
深度 copy 名为 ``field_name`` 的 field 到 ``new_field_name`` | |||||
:param field_name: 需要 copy 的 field。 | |||||
:param new_field_name: copy 生成的 field 名称 | |||||
:return: self | |||||
:param field_name: 需要 copy 的 field; | |||||
:param new_field_name: copy 生成的 field 名称; | |||||
:return: 数据集自身; | |||||
""" | """ | ||||
if not self.has_field(field_name): | if not self.has_field(field_name): | ||||
raise KeyError(f"Field:{field_name} not found in DataSet.") | raise KeyError(f"Field:{field_name} not found in DataSet.") | ||||
@@ -527,10 +523,10 @@ class DataSet: | |||||
def has_field(self, field_name: str) -> bool: | def has_field(self, field_name: str) -> bool: | ||||
r""" | r""" | ||||
判断 DataSet 中是否有名为 field_name 这个 field | |||||
判断 DataSet 中是否有名为 ``field_name`` 这个 field | |||||
:param field_name: field 的名称 | |||||
:return: 表示是否有名为 field_name 这个 field | |||||
:param field_name: field 的名称; | |||||
:return: 表示是否有名为 ``field_name`` 这个 field; | |||||
""" | """ | ||||
if isinstance(field_name, str): | if isinstance(field_name, str): | ||||
return field_name in self.field_arrays | return field_name in self.field_arrays | ||||
@@ -538,10 +534,10 @@ class DataSet: | |||||
def get_field(self, field_name: str) -> FieldArray: | def get_field(self, field_name: str) -> FieldArray: | ||||
r""" | r""" | ||||
获取 field_name 这个 field | |||||
获取名为 ``field_name`` 的 field | |||||
:param field_name: field 的名称 | |||||
:return: :class:`~fastNLP.FieldArray` | |||||
:param field_name: field 的名称; | |||||
:return: 一个 :class:`~fastNLP.core.dataset.FieldArray` 对象; | |||||
""" | """ | ||||
if field_name not in self.field_arrays: | if field_name not in self.field_arrays: | ||||
raise KeyError("Field name {} not found in DataSet".format(field_name)) | raise KeyError("Field name {} not found in DataSet".format(field_name)) | ||||
@@ -549,17 +545,13 @@ class DataSet: | |||||
def get_all_fields(self) -> dict: | def get_all_fields(self) -> dict: | ||||
r""" | r""" | ||||
返回一个 dict,key 为 field_name, value为对应的 :class:`~fastNLP.FieldArray` | |||||
:return: 返回如上所述的字典 | |||||
:return: 一个 dict,key 为 field_name, value为对应的 :class:`~fastNLP.core.dataset.FieldArray` 对象。 | |||||
""" | """ | ||||
return self.field_arrays | return self.field_arrays | ||||
def get_field_names(self) -> list: | def get_field_names(self) -> list: | ||||
r""" | r""" | ||||
返回一个 list,包含所有 field 的名字 | |||||
:return: 返回如上所述的列表 | |||||
:return: 一个 list,包含所有 field 的名字 | |||||
""" | """ | ||||
return sorted(self.field_arrays.keys()) | return sorted(self.field_arrays.keys()) | ||||
@@ -575,8 +567,8 @@ class DataSet: | |||||
r""" | r""" | ||||
将某个 field 重新命名. | 将某个 field 重新命名. | ||||
:param field_name: 原来的 field 名称。 | |||||
:param new_field_name: 修改为 new_name。 | |||||
:param field_name: 原来的 field 名称; | |||||
:param new_field_name: 修改为 new_name; | |||||
""" | """ | ||||
if field_name in self.field_arrays: | if field_name in self.field_arrays: | ||||
self.field_arrays[new_field_name] = self.field_arrays.pop(field_name) | self.field_arrays[new_field_name] = self.field_arrays.pop(field_name) | ||||
@@ -589,13 +581,13 @@ class DataSet: | |||||
new_field_name: str = None, num_proc: int = 0, | new_field_name: str = None, num_proc: int = 0, | ||||
progress_desc: str = None, progress_bar: str = 'rich'): | progress_desc: str = None, progress_bar: str = 'rich'): | ||||
r""" | r""" | ||||
将 :class:`DataSet` 每个 ``instance`` 中为 ``field_name`` 的 ``field`` 传给函数 ``func``,并写入到 ``new_field_name`` | |||||
将 :class:`DataSet` 每个 ``instance`` 中为 ``field_name`` 的 field 传给函数 ``func``,并写入到 ``new_field_name`` | |||||
中。 | 中。 | ||||
:param field_name: 传入 ``func`` 的 ``field`` 名称; | |||||
:param func: 对指定 ``field`` 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 ``field`` 的内容; | |||||
:param func: 对指定 fiel` 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容; | |||||
:param field_name: 传入 ``func`` 的 field 名称; | |||||
:param new_field_name: 函数执行结果写入的 ``field`` 名称。该函数会将 ``func`` 返回的内容放入到 ``new_field_name`` 对 | :param new_field_name: 函数执行结果写入的 ``field`` 名称。该函数会将 ``func`` 返回的内容放入到 ``new_field_name`` 对 | ||||
应的 ``field`` 中,注意如果名称与已有的 ``field`` 相同则会进行覆盖。如果为 ``None`` 则不会覆盖和创建 ``field`` ; | |||||
应的 ``field`` 中,注意如果名称与已有的 field 相同则会进行覆盖。如果为 ``None`` 则不会覆盖和创建 field ; | |||||
:param num_proc: 使用进程的数量。 | :param num_proc: 使用进程的数量。 | ||||
.. note:: | .. note:: | ||||
@@ -603,8 +595,8 @@ class DataSet: | |||||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | ||||
``func`` 函数中的打印将不会输出。 | ``func`` 函数中的打印将不会输出。 | ||||
:param progress_desc: 进度条的描述字符,默认为 ``Processing``; | |||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||||
:param progress_desc: 如果不为 ``None``,则会显示当前正在处理的进度条的名称; | |||||
:param progress_bar: 显示进度条的方式,支持 ``["rich", "tqdm", None]``。 | |||||
:return: 从函数 ``func`` 中得到的返回值; | :return: 从函数 ``func`` 中得到的返回值; | ||||
""" | """ | ||||
assert len(self) != 0, "Null DataSet cannot use apply_field()." | assert len(self) != 0, "Null DataSet cannot use apply_field()." | ||||
@@ -625,26 +617,27 @@ class DataSet: | |||||
modify_fields: bool = True, num_proc: int = 0, | modify_fields: bool = True, num_proc: int = 0, | ||||
progress_desc: str = None, progress_bar: str = 'rich'): | progress_desc: str = None, progress_bar: str = 'rich'): | ||||
r""" | r""" | ||||
将 ``DataSet`` 中的每个 ``Instance`` 中的名为 `field_name` 的field 传给 func,并获取它的返回值。 | |||||
func 可以返回一个或多个 field 上的结果。 | |||||
将 ``DataSet`` 中的每个 ``Instance`` 中的名为 `field_name` 的 field 传给 ``func``,并获取它的返回值。 | |||||
``func`` 可以返回一个或多个 field 上的结果。 | |||||
.. note:: | .. note:: | ||||
``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`~fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||||
``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`~fastNLP.core.dataset.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||||
``apply`` 区别的介绍。 | ``apply`` 区别的介绍。 | ||||
:param field_name: 传入func的是哪个field。 | |||||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||||
:param modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True | |||||
:param func: 对指定 fiel` 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容; | |||||
:param field_name: 传入 ``func`` 的 fiel` 名称; | |||||
:param new_field_name: 函数执行结果写入的 ``field`` 名称。该函数会将 ``func`` 返回的内容放入到 ``new_field_name`` 对 | |||||
应的 ``field`` 中,注意如果名称与已有的 field 相同则会进行覆盖。如果为 ``None`` 则不会覆盖和创建 field ; | |||||
:param num_proc: 使用进程的数量。 | :param num_proc: 使用进程的数量。 | ||||
.. note:: | .. note:: | ||||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | ||||
``func`` 函数中的打印将不会输出。 | ``func`` 函数中的打印将不会输出。 | ||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||||
:param progress_desc: 当显示 progress_bar 时,显示当前正在处理的进度条描述字符 | |||||
:return Dict[str:Field]: 返回一个字典 | |||||
:param progress_desc: 如果不为 ``None``,则会显示当前正在处理的进度条的名称; | |||||
:param progress_bar: 显示进度条的方式,支持 ``["rich", "tqdm", None]``。 | |||||
:return: 返回一个字典 | |||||
""" | """ | ||||
assert len(self) != 0, "Null DataSet cannot use apply_field()." | assert len(self) != 0, "Null DataSet cannot use apply_field()." | ||||
if not self.has_field(field_name=field_name): | if not self.has_field(field_name=field_name): | ||||
@@ -747,7 +740,7 @@ class DataSet: | |||||
def apply_more(self, func: Callable = None, modify_fields: bool = True, | def apply_more(self, func: Callable = None, modify_fields: bool = True, | ||||
num_proc: int = 0, progress_desc: str = '', progress_bar: str = 'rich'): | num_proc: int = 0, progress_desc: str = '', progress_bar: str = 'rich'): | ||||
r""" | r""" | ||||
将 ``DataSet`` 中每个 ``Instance`` 传入到func中,并获取它的返回值。func可以返回一个或多个 field 上的结果。 | |||||
将 ``DataSet`` 中每个 ``Instance`` 传入到 ``func`` 中,并获取它的返回值。``func``可以返回一个或多个 field 上的结果。 | |||||
.. note:: | .. note:: | ||||
``apply_more`` 与 ``apply`` 的区别: | ``apply_more`` 与 ``apply`` 的区别: | ||||
@@ -767,9 +760,9 @@ class DataSet: | |||||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | ||||
``func`` 函数中的打印将不会输出。 | ``func`` 函数中的打印将不会输出。 | ||||
:param progress_desc: 当 progress_bar 不为 None 时,可以显示当前正在处理的进度条名称 | |||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||||
:return Dict[str:Field]: 返回一个字典 | |||||
:param progress_desc: 当 progress_bar 不为 ``None`` 时,可以显示当前正在处理的进度条名称 | |||||
:param progress_bar: 显示进度条的方式,支持 ``["rich", "tqdm", None]``。 | |||||
:return: 返回一个字典 | |||||
""" | """ | ||||
assert callable(func), "The func is not callable." | assert callable(func), "The func is not callable." | ||||
assert len(self) != 0, "Null DataSet cannot use apply()." | assert len(self) != 0, "Null DataSet cannot use apply()." | ||||
@@ -808,10 +801,11 @@ class DataSet: | |||||
def apply(self, func: Callable = None, new_field_name: str = None, | def apply(self, func: Callable = None, new_field_name: str = None, | ||||
num_proc: int = 0, progress_bar: str = 'rich', progress_desc: str = ''): | num_proc: int = 0, progress_bar: str = 'rich', progress_desc: str = ''): | ||||
""" | """ | ||||
将 ``DataSet`` 中每个 ``Instance`` 传入到 ``func`` 中,并获取它的返回值。``func`` 仅能返回一个结果。 | |||||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | ||||
:param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | |||||
盖之前的field。如果为None则不创建新的field。 | |||||
:param new_field_name: 将 ``func`` 返回的内容放入到 ``new_field_name`` 这个 field中 ,如果名称与已有的 field 相同,则覆 | |||||
盖之前的 field。如果为 ``None`` 则不创建新的 field。 | |||||
:param num_proc: 使用进程的数量。 | :param num_proc: 使用进程的数量。 | ||||
.. note:: | .. note:: | ||||
@@ -819,8 +813,8 @@ class DataSet: | |||||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | ||||
``func`` 函数中的打印将不会输出。 | ``func`` 函数中的打印将不会输出。 | ||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||||
:param progress_desc: progress bar 显示的值,默认为空。 | |||||
:param progress_bar: 显示进度条的方式,支持 ``["rich", "tqdm", None]``。 | |||||
:param progress_desc: 如果不为 ``None``,则会显示当前正在处理的进度条的名称。 | |||||
""" | """ | ||||
assert callable(func), "The func you provide is not callable." | assert callable(func), "The func you provide is not callable." | ||||
assert len(self) != 0, "Null DataSet cannot use apply()." | assert len(self) != 0, "Null DataSet cannot use apply()." | ||||
@@ -838,10 +832,10 @@ class DataSet: | |||||
def add_seq_len(self, field_name: str, new_field_name='seq_len'): | def add_seq_len(self, field_name: str, new_field_name='seq_len'): | ||||
r""" | r""" | ||||
将使用 len() 直接对 field_name 中每个元素作用,将其结果作为 sequence length, 并放入 seq_len 这个 field。 | |||||
将使用 :func:`len` 直接对 ``field_name`` 中每个元素作用,将其结果作为 sequence length, 并放入 ``new_field_name`` 这个 field。 | |||||
:param field_name: 需要处理的 field_name | :param field_name: 需要处理的 field_name | ||||
:param new_field_name: str. 新的 field_name | |||||
:param new_field_name: 新的 field_name | |||||
:return: | :return: | ||||
""" | """ | ||||
if self.has_field(field_name=field_name): | if self.has_field(field_name=field_name): | ||||
@@ -852,10 +846,10 @@ class DataSet: | |||||
def drop(self, func: Callable, inplace=True): | def drop(self, func: Callable, inplace=True): | ||||
r""" | r""" | ||||
删除某些 Instance。 需要注意的时func 接受一个 Instance ,返回 bool 值。返回值为 True 时, | |||||
删除某些 Instance。 需要注意的是 ``func`` 接受一个 Instance ,返回 bool 值。返回值为 ``True`` 时, | |||||
该 Instance 会被移除或者不会包含在返回的 DataSet 中。 | 该 Instance 会被移除或者不会包含在返回的 DataSet 中。 | ||||
:param func: 接受一个 Instance 作为参数,返回 bool 值。为 True 时删除该 instance | |||||
:param func: 接受一个 Instance 作为参数,返回 bool 值。为 ``True`` 时删除该 instance | |||||
:param inplace: 是否在当前 DataSet 中直接删除 instance;如果为 False,将返回一个新的 DataSet。 | :param inplace: 是否在当前 DataSet 中直接删除 instance;如果为 False,将返回一个新的 DataSet。 | ||||
:return: DataSet | :return: DataSet | ||||
@@ -875,11 +869,11 @@ class DataSet: | |||||
def split(self, ratio: float, shuffle=True): | def split(self, ratio: float, shuffle=True): | ||||
r""" | r""" | ||||
将 DataSet 按照 ratio 的比例拆分,返回两个 DataSet | |||||
将 DataSet 按照 ``ratio`` 的比例拆分,返回两个 DataSet | |||||
:param ratio: 0<ratio<1, 返回的第一个 DataSet 拥有 `ratio` 这么多数据,第二个 DataSet 拥有 `(1-ratio)` 这么多数据 | |||||
:param shuffle: 在 split 前是否 shuffle 一下。为 False,返回的第一个 dataset 就是当前 dataset 中前 `ratio` 比例的数据, | |||||
:return: [ :class:`~fastNLP.读取后的DataSet` , :class:`~fastNLP.读取后的DataSet` ] | |||||
:param ratio: 0<ratio<1, 返回的第一个 DataSet 拥有 ``ratio`` 比例的数据,第二个 DataSet 拥有 ``1-ratio`` 的数据; | |||||
:param shuffle: 在拆分前是否进行排序。为 False,返回的第一个 dataset 就是当前 dataset 中前 ``ratio`` 比例的数据; | |||||
:return: 拆分后的两个 DataSet; | |||||
""" | """ | ||||
assert len(self) > 1, f'DataSet with {len(self)} instance cannot be split.' | assert len(self) > 1, f'DataSet with {len(self)} instance cannot be split.' | ||||
assert isinstance(ratio, float) | assert isinstance(ratio, float) | ||||
@@ -906,9 +900,9 @@ class DataSet: | |||||
def save(self, path: str) -> None: | def save(self, path: str) -> None: | ||||
r""" | r""" | ||||
保存DataSet. | |||||
保存 DataSet。 | |||||
:param path: 将DataSet存在哪个路径 | |||||
:param path: 保存路径; | |||||
""" | """ | ||||
with open(path, 'wb') as f: | with open(path, 'wb') as f: | ||||
pickle.dump(self, f) | pickle.dump(self, f) | ||||
@@ -916,10 +910,10 @@ class DataSet: | |||||
@staticmethod | @staticmethod | ||||
def load(path: str): | def load(path: str): | ||||
r""" | r""" | ||||
从保存的 DataSet pickle文件的路径中读取DataSet | |||||
从保存的 DataSet pickle 文件的路径中读取 DataSet | |||||
:param path: 从哪里读取 DataSet | |||||
:return: 读取后的 :class:`~fastNLP.读取后的DataSet`。 | |||||
:param path: 读取路径; | |||||
:return: 读取出的 DataSet | |||||
""" | """ | ||||
with open(path, 'rb') as f: | with open(path, 'rb') as f: | ||||
d = pickle.load(f) | d = pickle.load(f) | ||||
@@ -928,16 +922,16 @@ class DataSet: | |||||
def concat(self, dataset: 'DataSet', inplace:bool=True, field_mapping:Dict=None) -> 'DataSet': | def concat(self, dataset: 'DataSet', inplace:bool=True, field_mapping:Dict=None) -> 'DataSet': | ||||
""" | """ | ||||
将当前 dataset 与输入的 dataset 结合成一个更大的 dataset,需要保证两个 dataset 都包含了相同的 field。结合后的 dataset | |||||
的 field_name 和 _collator 以当前 dataset 为准。当 dataset 中包含的 field 多于当前的 dataset,则多余的 field 会被忽略; | |||||
若 dataset 中未包含所有当前 dataset 含有 field,则会报错。 | |||||
将当前 DataSet 与输入的 ``dataset`` 结合成一个更大的 dataset,需要保证两个 dataset 都包含了相同的 field。结合后的 dataset | |||||
的 field_name 和 _collator 以当前 dataset 为准。若 ``dataset`` 中包含的 field 多于当前的 DataSet,则多余的 field 会被忽略; | |||||
若 ``dataset`` 中未包含所有当前 DataSet 含有 field,则会报错。 | |||||
:param dataset: 需要和当前 dataset concat的 dataset | |||||
:param inplace: 是否直接将 dataset 组合到当前 dataset 中 | |||||
:param field_mapping: 当传入的 dataset 中的 field 名称和当前 dataset 不一致时,需要通过 field_mapping 把输入的 dataset 中的 | |||||
field 名称映射到当前 field. field_mapping 为 dict 类型,key 为 dataset 中的 field 名称,value 是需要映射成的名称 | |||||
:param dataset: 需要和当前 DataSet 拼接的 ``dataset``; | |||||
:param inplace: 是否直接将 ``dataset`` 组合到当前 DataSet 中; | |||||
:param field_mapping: 当传入的 ``dataset`` 中的 field 名称和当前 dataset 不一致时,需要通过 ``field_mapping`` 把输入的 ``dataset`` | |||||
中的 field 名称映射到当前 field。``field_mapping`` 为 dict 类型,key 为 11dataset`` 中的 field 名称,value 是需要映射成的名称 | |||||
:return: :class:`~fastNLP.core.dataset.DataSet`` | |||||
:return: :class:`~fastNLP.core.dataset.DataSet` | |||||
""" | """ | ||||
assert isinstance(dataset, DataSet), "Can only concat two datasets." | assert isinstance(dataset, DataSet), "Can only concat two datasets." | ||||
@@ -966,7 +960,8 @@ class DataSet: | |||||
@classmethod | @classmethod | ||||
def from_pandas(cls, df): | def from_pandas(cls, df): | ||||
""" | """ | ||||
从 ``pandas.DataFrame`` 中读取数据转为 DataSet | |||||
从 :class:`pandas.DataFrame` 中读取并数据转化为 DataSet | |||||
:param df: 使用 pandas 读取的数据 | :param df: 使用 pandas 读取的数据 | ||||
:return: | :return: | ||||
""" | """ | ||||
@@ -975,7 +970,7 @@ class DataSet: | |||||
def to_pandas(self): | def to_pandas(self): | ||||
""" | """ | ||||
将 DataSet 数据转为 ``pandas.DataFrame`` 类型的数据 | |||||
将 DataSet 数据转为 :class:`pandas.DataFrame` 类型的数据 | |||||
:return: | :return: | ||||
""" | """ | ||||
@@ -1003,23 +998,22 @@ class DataSet: | |||||
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, | def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, | ||||
pad_fn: Callable = None) -> Collator: | pad_fn: Callable = None) -> Collator: | ||||
""" | """ | ||||
``DataSet`` 中想要对绑定的 collator 进行调整可以调用此函数。 ``collator`` 为 :class:`~fastNLP.core.collators.Collator` | |||||
时该函数才有效。调用该函数可以对 field 内容的 pad_val, dtype, backend 等进行调整。 | |||||
:param field_name: 需要调整的 field 的名称。如果 DataSet 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); | |||||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | |||||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | |||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 | |||||
无意义。 | |||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | |||||
:param backend: 可选['raw', 'numpy', 'torch', 'torch', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, | |||||
torch.Tensor, torch.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 | |||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | |||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | |||||
形式,输出将被直接作为结果输出。 | |||||
:return: 返回 Collator | |||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | |||||
:param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||||
如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 "_single" 。 | |||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``, | |||||
该值无意义。 | |||||
:param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。 | |||||
:param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`, | |||||
:class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。 | |||||
若 ``pad_val`` 为 ``None`` ,该值无意义 。 | |||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的 | |||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。 | |||||
:return: 返回自身的 collator; | |||||
""" | """ | ||||
if isinstance(self.collator, Collator): | if isinstance(self.collator, Collator): | ||||
self.collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) | self.collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) | ||||
@@ -1030,16 +1024,14 @@ class DataSet: | |||||
def set_ignore(self, *field_names) -> Collator: | def set_ignore(self, *field_names) -> Collator: | ||||
""" | """ | ||||
``DataSet`` 中想要对绑定的 collator 进行调整可以调用此函数。 ``collator`` 为 :class:`~fastNLP.core.collators.Collator` | ``DataSet`` 中想要对绑定的 collator 进行调整可以调用此函数。 ``collator`` 为 :class:`~fastNLP.core.collators.Collator` | ||||
时该函数才有效。调用该函数可以设置忽略输出某些 field 的内容,被设置的 field 将在 batch 的输出中被忽略。 | |||||
Example:: | |||||
时该函数才有效。调用该函数可以设置忽略输出某些 field 的内容,被设置的 field 将在 batch 的输出中被忽略:: | |||||
collator.set_ignore('field1', 'field2') | |||||
dataset.set_ignore('field1', 'field2') | |||||
:param field_names: 需要忽略的 field 的名称。如果 DataSet 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 | |||||
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | |||||
:return: 返回 Collator 自身 | |||||
:param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||||
如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||||
:return: 返回自身的 collator; | |||||
""" | """ | ||||
if isinstance(self.collator, Collator): | if isinstance(self.collator, Collator): | ||||
self.collator.set_ignore(*field_names) | self.collator.set_ignore(*field_names) | ||||
@@ -14,15 +14,14 @@ import numpy as np | |||||
class FieldArray: | class FieldArray: | ||||
""" | |||||
:class:`~fastNLP.core.dataset.DatSet` 中用于表示列的数据类型。 | |||||
def __init__(self, name: str, content): | |||||
""" | |||||
初始化 FieldArray | |||||
:param name: 字符串的名称 | |||||
:param content: 任意类型的数据 | |||||
:param name: 字符串的名称 | |||||
:param content: 任意类型的数据 | |||||
""" | |||||
""" | |||||
def __init__(self, name: str, content): | |||||
if len(content) == 0: | if len(content) == 0: | ||||
raise RuntimeError("Empty fieldarray is not allowed.") | raise RuntimeError("Empty fieldarray is not allowed.") | ||||
_content = content | _content = content | ||||
@@ -36,18 +35,15 @@ class FieldArray: | |||||
def append(self, val: Any) -> None: | def append(self, val: Any) -> None: | ||||
r""" | r""" | ||||
:param val: 把该 val append 到 fieldarray。 | |||||
:return: | |||||
:param val: 把该 ``val`` 添加到 fieldarray 中。 | |||||
""" | """ | ||||
self.content.append(val) | self.content.append(val) | ||||
def pop(self, index: int) -> None: | def pop(self, index: int) -> None: | ||||
r""" | r""" | ||||
删除该 field 中 index 处的元素 | |||||
删除该 field 中 ``index`` 处的元素 | |||||
:param index: 从 ``0`` 开始的数据下标。 | :param index: 从 ``0`` 开始的数据下标。 | ||||
:return: | |||||
""" | """ | ||||
self.content.pop(index) | self.content.pop(index) | ||||
@@ -60,10 +56,10 @@ class FieldArray: | |||||
def get(self, indices: Union[int, List[int]]): | def get(self, indices: Union[int, List[int]]): | ||||
r""" | r""" | ||||
根据给定的 indices 返回内容。 | |||||
根据给定的 ``indices`` 返回内容。 | |||||
:param indices: 获取 indices 对应的内容。 | |||||
:return: 根据给定的 indices 返回的内容,可能是单个值 或 ``ndarray`` | |||||
:param indices: 获取 ``indices`` 对应的内容。 | |||||
:return: 根据给定的 ``indices`` 返回的内容,可能是单个值 或 :class:`numpy.ndarray` | |||||
""" | """ | ||||
if isinstance(indices, int): | if isinstance(indices, int): | ||||
if indices == -1: | if indices == -1: | ||||
@@ -80,16 +76,16 @@ class FieldArray: | |||||
r""" | r""" | ||||
返回长度 | 返回长度 | ||||
:return length: | |||||
:return: | |||||
""" | """ | ||||
return len(self.content) | return len(self.content) | ||||
def split(self, sep: str = None, inplace: bool = True): | def split(self, sep: str = None, inplace: bool = True): | ||||
r""" | r""" | ||||
依次对自身的元素使用 ``.split()`` 方法,应该只有当本 field 的元素为 ``str`` 时,该方法才有用。 | |||||
依次对自身的元素使用 ``.split()`` 方法,应该只有当本 field 的元素为 :class:`str` 时,该方法才有用。 | |||||
:param sep: 分割符,如果为 ``None`` 则直接调用 ``str.split()``。 | :param sep: 分割符,如果为 ``None`` 则直接调用 ``str.split()``。 | ||||
:param inplace: 如果为 ``True``,则将新生成值替换本 field。否则返回 ``list``。 | |||||
:param inplace: 如果为 ``True``,则将新生成值替换本 field。否则返回 :class:`list`。 | |||||
:return: List[List[str]] or self | :return: List[List[str]] or self | ||||
""" | """ | ||||
new_contents = [] | new_contents = [] | ||||
@@ -104,10 +100,11 @@ class FieldArray: | |||||
def int(self, inplace: bool = True): | def int(self, inplace: bool = True): | ||||
r""" | r""" | ||||
将本 field 中的值调用 ``int(cell)``. 支持 field 中内容为以下两种情况: | 将本 field 中的值调用 ``int(cell)``. 支持 field 中内容为以下两种情况: | ||||
* ['1', '2', ...](即 field 中每个值为 ``str`` 的), | |||||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 ``list`` ,``list`` 中的值会被依次转换。) | |||||
:param inplace: 如果为 ``True``,则将新生成值替换本 field。否则返回 ``list``。 | |||||
* ['1', '2', ...](即 field 中每个值为 :class:`str` 的), | |||||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 :class:`list` ,:class:`list` 中的值会被依次转换。) | |||||
:param inplace: 如果为 ``True``,则将新生成值替换本 field,并返回当前 field 。否则返回 :class:`list`。 | |||||
:return: List[int], List[List[int]], self | :return: List[int], List[List[int]], self | ||||
""" | """ | ||||
new_contents = [] | new_contents = [] | ||||
@@ -126,10 +123,10 @@ class FieldArray: | |||||
r""" | r""" | ||||
将本 field 中的值调用 ``float(cell)``. 支持 field 中内容为以下两种情况: | 将本 field 中的值调用 ``float(cell)``. 支持 field 中内容为以下两种情况: | ||||
* ['1', '2', ...](即 field 中每个值为 ``str`` 的), | |||||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 ``list``,``list`` 中的值会被依次转换。) | |||||
* ['1', '2', ...](即 field 中每个值为 :class:`str` 的), | |||||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 :class:`list` ,:class:`list` 中的值会被依次转换。) | |||||
:param inplace: 如果为 ``True``,则将新生成值替换本 ``field``。否则返回 ``list``。 | |||||
:param inplace: 如果为 ``True``,则将新生成值替换本 field,并返回当前 field 。否则返回 :class:`list`。 | |||||
:return: | :return: | ||||
""" | """ | ||||
new_contents = [] | new_contents = [] | ||||
@@ -148,10 +145,10 @@ class FieldArray: | |||||
r""" | r""" | ||||
将本field中的值调用 ``bool(cell)``. 支持 field 中内容为以下两种情况 | 将本field中的值调用 ``bool(cell)``. 支持 field 中内容为以下两种情况 | ||||
* ['1', '2', ...](即 field 中每个值为 ``str`` 的), | |||||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 ``list``,``list`` 中的值会被依次转换。) | |||||
* ['1', '2', ...](即 field 中每个值为 :class:`str` 的), | |||||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 :class:`list` ,:class:`list` 中的值会被依次转换。) | |||||
:param inplace: 如果为 ``True``,则将新生成值替换本 ``field``。否则返回 ``list``。 | |||||
:param inplace: 如果为 ``True``,则将新生成值替换本 field,并返回当前 field 。否则返回 :class:`list`。 | |||||
:return: | :return: | ||||
""" | """ | ||||
new_contents = [] | new_contents = [] | ||||
@@ -169,12 +166,12 @@ class FieldArray: | |||||
def lower(self, inplace=True): | def lower(self, inplace=True): | ||||
r""" | r""" | ||||
将本 field 中的值调用 ``cell.lower()``. 支持 field 中内容为以下两种情况 | |||||
将本 field 中的值调用 ``cell.lower()``, 支持 field 中内容为以下两种情况 | |||||
* ['1', '2', ...](即 ``field`` 中每个值为 ``str`` 的), | |||||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 ``list``,``list``中的值会被依次转换。) | |||||
* ['1', '2', ...](即 field 中每个值为 :class:`str` 的), | |||||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 :class:`list` ,:class:`list` 中的值会被依次转换。) | |||||
:param inplace: 如果为 ``True``,则将新生成值替换本 field。否则返回 ``list``。 | |||||
:param inplace: 如果为 ``True``,则将新生成值替换本 field,并返回当前 field 。否则返回 :class:`list`。 | |||||
:return: List[int], List[List[int]], self | :return: List[int], List[List[int]], self | ||||
""" | """ | ||||
new_contents = [] | new_contents = [] | ||||
@@ -191,12 +188,12 @@ class FieldArray: | |||||
def upper(self, inplace=True): | def upper(self, inplace=True): | ||||
r""" | r""" | ||||
将本 field 中的值调用 ``cell.lower()``. 支持 field 中内容为以下两种情况 | |||||
将本 field 中的值调用 ``cell.upper()``, 支持 field 中内容为以下两种情况 | |||||
* ['1', '2', ...](即 field 中每个值为 ``str`` 的), | |||||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 ``list``,``list`` 中的值会被依次转换。) | |||||
* ['1', '2', ...](即 field 中每个值为 :class:`str` 的), | |||||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 :class:`list` ,:class:`list` 中的值会被依次转换。) | |||||
:param inplace: 如果为 ``True``,则将新生成值替换本 field。否则返回 ``list``。 | |||||
:param inplace: 如果为 ``True``,则将新生成值替换本 field,并返回当前 field 。否则返回 :class:`list`。 | |||||
:return: List[int], List[List[int]], self | :return: List[int], List[List[int]], self | ||||
""" | """ | ||||
new_contents = [] | new_contents = [] | ||||
@@ -211,11 +208,11 @@ class FieldArray: | |||||
raise e | raise e | ||||
return self._after_process(new_contents, inplace=inplace) | return self._after_process(new_contents, inplace=inplace) | ||||
def value_count(self): | |||||
def value_count(self) -> Counter: | |||||
r""" | r""" | ||||
返回该 field 下不同 value的 数量。多用于统计 label 数量 | |||||
返回该 field 下不同 value 的数量。多用于统计 label 数量 | |||||
:return: Counter, key 是 label,value 是出现次数 | |||||
:return: 计数结果,key 是 label,value 是出现次数 | |||||
""" | """ | ||||
count = Counter() | count = Counter() | ||||
@@ -1,7 +1,6 @@ | |||||
r""" | r""" | ||||
instance 模块实现了 Instance 类在 fastNLP 中对应 sample。一个 sample 可以认为是一个 Instance 类型的对象。 | |||||
便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset` 。 | |||||
instance 模块实现了 Instance 类,即在 fastNLP 中 sample 对应的类型。一个 sample 可以认为是一个 Instance 类型的对象。 | |||||
便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset.dataset` 。 | |||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
@@ -15,9 +14,9 @@ from fastNLP.core.utils.utils import pretty_table_printer | |||||
class Instance(Mapping): | class Instance(Mapping): | ||||
r""" | r""" | ||||
Instance 是 fastNLP 中对应一个 sample 的类。每个 sample 在 fastNLP 中是一个 Instance 对象。 | Instance 是 fastNLP 中对应一个 sample 的类。每个 sample 在 fastNLP 中是一个 Instance 对象。 | ||||
Instance 一般与 :class:`~fastNLP.DataSet` 一起使用, Instance 的初始化如下面的 Example 所示:: | |||||
Instance 一般与 :class:`~fastNLP.DataSet` 一起使用, Instance 的初始化如下面的代码所示:: | |||||
>>> instance = Instance(input="this is a demo sentence", label='good') # 请补充完整 | |||||
>>> instance = Instance(input="this is a demo sentence", label='good') | |||||
""" | """ | ||||
@@ -3,20 +3,34 @@ __all__ = [ | |||||
'TorchDriver', | 'TorchDriver', | ||||
"TorchSingleDriver", | "TorchSingleDriver", | ||||
"TorchDDPDriver", | "TorchDDPDriver", | ||||
"DeepSpeedDriver", | |||||
"PaddleDriver", | "PaddleDriver", | ||||
"PaddleSingleDriver", | "PaddleSingleDriver", | ||||
"PaddleFleetDriver", | "PaddleFleetDriver", | ||||
"JittorDriver", | "JittorDriver", | ||||
"JittorSingleDriver", | "JittorSingleDriver", | ||||
"JittorMPIDriver", | "JittorMPIDriver", | ||||
'TorchSingleDriver', | |||||
'TorchDDPDriver', | |||||
'PaddleDriver', | |||||
'PaddleSingleDriver', | |||||
'PaddleFleetDriver', | |||||
'JittorDriver', | |||||
'JittorSingleDriver', | |||||
'JittorMPIDriver', | |||||
'OneflowDriver', | |||||
'OneflowSingleDriver', | |||||
'OneflowDDPDriver', | |||||
'torch_seed_everything', | 'torch_seed_everything', | ||||
'paddle_seed_everything', | 'paddle_seed_everything', | ||||
'oneflow_seed_everything', | |||||
'optimizer_state_to_device' | 'optimizer_state_to_device' | ||||
] | ] | ||||
from .torch_driver import TorchDriver, TorchSingleDriver, TorchDDPDriver, torch_seed_everything, optimizer_state_to_device | |||||
from .torch_driver import TorchDriver, TorchSingleDriver, TorchDDPDriver, DeepSpeedDriver, torch_seed_everything, optimizer_state_to_device | |||||
from .jittor_driver import JittorDriver, JittorMPIDriver, JittorSingleDriver | from .jittor_driver import JittorDriver, JittorMPIDriver, JittorSingleDriver | ||||
from .paddle_driver import PaddleDriver, PaddleFleetDriver, PaddleSingleDriver, paddle_seed_everything | from .paddle_driver import PaddleDriver, PaddleFleetDriver, PaddleSingleDriver, paddle_seed_everything | ||||
from .oneflow_driver import OneflowDriver, OneflowSingleDriver, OneflowDDPDriver, oneflow_seed_everything | |||||
from .driver import Driver | from .driver import Driver | ||||
@@ -1,6 +1,7 @@ | |||||
from typing import Union, Optional, List | from typing import Union, Optional, List | ||||
from .driver import Driver | from .driver import Driver | ||||
from ..utils import is_torch_module, is_paddle_module, is_jittor_module, is_oneflow_module | |||||
def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver: | def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver: | ||||
@@ -17,7 +18,19 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, | |||||
if isinstance(driver, Driver): | if isinstance(driver, Driver): | ||||
return driver | return driver | ||||
if driver in {"torch", "fairscale"}: | |||||
if driver == "auto": | |||||
if is_torch_module(model): | |||||
driver = "torch" | |||||
elif is_paddle_module(model): | |||||
driver = "paddle" | |||||
elif is_jittor_module(model): | |||||
driver = "jittor" | |||||
elif is_oneflow_module(model): | |||||
driver = "oneflow" | |||||
else: | |||||
raise ValueError(f"Cannot choose driver automatically based on model, please set `driver` specifically.") | |||||
if driver in {"torch", "fairscale", "deepspeed"}: | |||||
from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver | from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver | ||||
return initialize_torch_driver(driver, device, model, **kwargs) | return initialize_torch_driver(driver, device, model, **kwargs) | ||||
elif driver in {"jittor"}: | elif driver in {"jittor"}: | ||||
@@ -26,6 +39,9 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, | |||||
elif driver in {"paddle"}: | elif driver in {"paddle"}: | ||||
from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver | from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver | ||||
return initialize_paddle_driver(driver, device, model, **kwargs) | return initialize_paddle_driver(driver, device, model, **kwargs) | ||||
elif driver in {"oneflow"}: | |||||
from fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver import initialize_oneflow_driver | |||||
return initialize_oneflow_driver(driver, device, model, **kwargs) | |||||
else: | else: | ||||
raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale', " | raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale', " | ||||
"'jittor', 'paddle'].") | |||||
"'jittor', 'paddle', 'oneflow'].") |
@@ -6,6 +6,7 @@ from dataclasses import dataclass | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | ||||
from fastNLP.core.drivers.driver import Driver | from fastNLP.core.drivers.driver import Driver | ||||
from fastNLP.core.dataloaders import JittorDataLoader | from fastNLP.core.dataloaders import JittorDataLoader | ||||
from fastNLP.core.dataloaders import OverfitDataLoader | |||||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler | from fastNLP.core.samplers import ReproducibleSampler, RandomSampler | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.utils import apply_to_collection, nullcontext | from fastNLP.core.utils import apply_to_collection, nullcontext | ||||
@@ -39,20 +40,22 @@ __all__ = [ | |||||
class JittorDriver(Driver): | class JittorDriver(Driver): | ||||
r""" | r""" | ||||
``Jittor`` 框架的 ``Driver`` | |||||
``Jittor`` 框架的 ``Driver``,是 ``JittorSingleDevice`` 和 ``JittorMPIDriver`` 的父类。 | |||||
.. note:: | |||||
.. warning:: | |||||
这是一个正在开发中的功能,敬请期待。 | |||||
您不应当直接初始化该类,然后传入给 ``Trainer``,换句话说,您应当使用该类的子类 ``JittorSingleDriver`` 和 ``TorchDDPDriver``,而不是 | |||||
该类本身; | |||||
.. todo:: | |||||
.. note:: | |||||
实现 fp16 的设置,且支持 cpu 和 gpu 的切换; | |||||
实现用于断点重训的 save 和 load 函数; | |||||
您可以在使用 ``JittorSingleDevice`` 和 ``JittorMPIDriver`` 时使用 ``JittorDriver`` 提供的接口; | |||||
:param model: 训练时使用的 **jittor** 模型; | |||||
:param fp16: 是否开启混合精度训练; | |||||
:param jittor_kwargs: | |||||
""" | """ | ||||
def __init__(self, model, fp16: bool = False, **kwargs): | |||||
def __init__(self, model, fp16: bool = False, jittor_kwargs: Dict = None, **kwargs): | |||||
if not isinstance(model, Module): | if not isinstance(model, Module): | ||||
raise ValueError(f"Parameter `model` can not be `{type(model)}` in `JittorDriver`, it should be exactly " | raise ValueError(f"Parameter `model` can not be `{type(model)}` in `JittorDriver`, it should be exactly " | ||||
f"`jittor.Module` type.") | f"`jittor.Module` type.") | ||||
@@ -64,12 +67,13 @@ class JittorDriver(Driver): | |||||
jt.flags.auto_mixed_precision_level = 0 | jt.flags.auto_mixed_precision_level = 0 | ||||
self.fp16 = fp16 | self.fp16 = fp16 | ||||
self._auto_cast = nullcontext | self._auto_cast = nullcontext | ||||
self._jittor_kwargs = jittor_kwargs if jittor_kwargs is not None else {} | |||||
# 用来设置是否关闭 auto_param_call 中的参数匹配问题; | # 用来设置是否关闭 auto_param_call 中的参数匹配问题; | ||||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | ||||
def check_dataloader_legality(self, dataloader): | def check_dataloader_legality(self, dataloader): | ||||
if not isinstance(dataloader, (Dataset, JittorDataLoader)): | |||||
if not isinstance(dataloader, (Dataset, JittorDataLoader, OverfitDataLoader)): | |||||
raise TypeError(f"{Dataset} or {JittorDataLoader} is expected, instead of `{type(dataloader)}`") | raise TypeError(f"{Dataset} or {JittorDataLoader} is expected, instead of `{type(dataloader)}`") | ||||
if len(dataloader) == 0: | if len(dataloader) == 0: | ||||
logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it " | logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it " | ||||
@@ -138,26 +142,12 @@ class JittorDriver(Driver): | |||||
num_consumed_batches = states.pop('num_consumed_batches') | num_consumed_batches = states.pop('num_consumed_batches') | ||||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | ||||
sampler_states = sampler.state_dict() | sampler_states = sampler.state_dict() | ||||
# 需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||||
# 会造成多余实际消耗的问题。因为 | |||||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||||
if num_consumed_samples_array is not None: | |||||
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||||
if dataloader_args.batch_size is not None: | |||||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||||
else: # 有可能 batch_size 为 None,就只有损失精度了 | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
"it may cause missing some samples when reload.") | |||||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||||
if dataloader_args.batch_size is not None: | |||||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
* num_consumed_batches | |||||
else: | else: | ||||
if dataloader_args.batch_size is not None: | |||||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
* num_consumed_batches | |||||
else: | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
"it may cause missing some samples when reload.") | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
"it may cause missing some samples when reload.") | |||||
states['sampler_states'] = sampler_states | states['sampler_states'] = sampler_states | ||||
else: | else: | ||||
@@ -34,10 +34,11 @@ class JittorMPIDriver(JittorDriver): | |||||
parallel_device: None, | parallel_device: None, | ||||
is_pull_by_jittor_run: bool = False, | is_pull_by_jittor_run: bool = False, | ||||
fp16: bool = False, | fp16: bool = False, | ||||
jittor_kwargs: Dict = None, | |||||
**kwargs | **kwargs | ||||
): | ): | ||||
super(JittorMPIDriver, self).__init__(model, fp16=fp16, **kwargs) | |||||
super(JittorMPIDriver, self).__init__(model, fp16=fp16, jittor_kwargs=jittor_kwargs, **kwargs) | |||||
raise NotImplementedError("MPI for Jittor is not supported right now.") | raise NotImplementedError("MPI for Jittor is not supported right now.") | ||||
self.is_pull_by_jittor_run = is_pull_by_jittor_run | self.is_pull_by_jittor_run = is_pull_by_jittor_run | ||||
@@ -25,15 +25,6 @@ class JittorSingleDriver(JittorDriver): | |||||
r""" | r""" | ||||
``Jittor`` 框架下用于 ``cpu`` 和单卡 ``gpu`` 运算的 ``Driver``。 | ``Jittor`` 框架下用于 ``cpu`` 和单卡 ``gpu`` 运算的 ``Driver``。 | ||||
.. note:: | |||||
这是一个正在开发中的功能,敬请期待。 | |||||
.. todo:: | |||||
支持 cpu 和 gpu 的切换; | |||||
实现断点重训中替换 dataloader 的 set_dist_repro_dataloader 函数 | |||||
:param model: 传入给 ``Trainer`` 的 ``model`` 参数; | :param model: 传入给 ``Trainer`` 的 ``model`` 参数; | ||||
:param device: 训练和模型所在的设备,在 **Jittor** 中,应当为以下值之一:``[None, 'cpu', 'gpu', 'cuda']``; | :param device: 训练和模型所在的设备,在 **Jittor** 中,应当为以下值之一:``[None, 'cpu', 'gpu', 'cuda']``; | ||||
@@ -43,12 +34,13 @@ class JittorSingleDriver(JittorDriver): | |||||
表示在显卡设备上进行训练; | 表示在显卡设备上进行训练; | ||||
:param fp16: 是否开启 fp16; | :param fp16: 是否开启 fp16; | ||||
:param jittor_kwargs: | |||||
""" | """ | ||||
def __init__(self, model, device=None, fp16: bool = False, **kwargs): | |||||
def __init__(self, model, device=None, fp16: bool = False, jittor_kwargs: Dict = None, **kwargs): | |||||
if device not in [None, "cpu", "gpu", "cuda"]: | if device not in [None, "cpu", "gpu", "cuda"]: | ||||
raise RuntimeError("Parameter `device` should be one of [None, 'cpu', 'gpu', 'cuda'] .") | raise RuntimeError("Parameter `device` should be one of [None, 'cpu', 'gpu', 'cuda'] .") | ||||
super(JittorSingleDriver, self).__init__(model, fp16) | |||||
super(JittorSingleDriver, self).__init__(model, fp16, jittor_kwargs=jittor_kwargs) | |||||
self.model_device = device if device is not None else "cpu" | self.model_device = device if device is not None else "cpu" | ||||
@@ -118,14 +110,14 @@ class JittorSingleDriver(JittorDriver): | |||||
if args.sampler is None: | if args.sampler is None: | ||||
sampler = RandomSampler(args.dataset, args.shuffle) | sampler = RandomSampler(args.dataset, args.shuffle) | ||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
elif isinstance(args.sampler, JittorRandomSampler): | |||||
elif type(args.sampler) is JittorRandomSampler: | |||||
if getattr(args.sampler, '_num_samples', None) is None \ | if getattr(args.sampler, '_num_samples', None) is None \ | ||||
and getattr(args.sampler, 'rep', False) is False: | and getattr(args.sampler, 'rep', False) is False: | ||||
# 如果本来就是随机的,并且没有定制,直接替换掉吧。 | # 如果本来就是随机的,并且没有定制,直接替换掉吧。 | ||||
sampler = RandomSampler(args.sampler.dataset, shuffle=True) | sampler = RandomSampler(args.sampler.dataset, shuffle=True) | ||||
logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.") | logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.") | ||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
elif isinstance(args.sampler, JittorSequentialSampler): | |||||
elif type(args.sampler) is JittorSequentialSampler: | |||||
# 需要替换为不要 shuffle 的。 | # 需要替换为不要 shuffle 的。 | ||||
sampler = RandomSampler(args.sampler.dataset, shuffle=False) | sampler = RandomSampler(args.sampler.dataset, shuffle=False) | ||||
logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.") | logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.") | ||||
@@ -14,6 +14,7 @@ from fastNLP.envs import ( | |||||
FASTNLP_BACKEND_LAUNCH, | FASTNLP_BACKEND_LAUNCH, | ||||
FASTNLP_GLOBAL_SEED, | FASTNLP_GLOBAL_SEED, | ||||
) | ) | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
if _NEED_IMPORT_JITTOR: | if _NEED_IMPORT_JITTOR: | ||||
@@ -63,6 +64,9 @@ def replace_batch_sampler(dataloader, batch_sampler): | |||||
"or report this bug to us.") | "or report this bug to us.") | ||||
def replace_sampler(dataloader: Union["Dataset", "JittorDataLoader"], sampler): | def replace_sampler(dataloader: Union["Dataset", "JittorDataLoader"], sampler): | ||||
batch_sampler = getattr(dataloader, "sampler") | |||||
if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler): | |||||
raise RuntimeError("It should not be running here, please report a bug to us.") | |||||
if isinstance(dataloader, JittorDataLoader): | if isinstance(dataloader, JittorDataLoader): | ||||
init_params = dict(inspect.signature(dataloader.__init__).parameters) | init_params = dict(inspect.signature(dataloader.__init__).parameters) | ||||
reconstruct_args = {name: getattr(dataloader, name, p.default) for name, p in init_params.items()} | reconstruct_args = {name: getattr(dataloader, name, p.default) for name, p in init_params.items()} | ||||
@@ -0,0 +1,18 @@ | |||||
__all__ = [ | |||||
"OneflowDDPDriver", | |||||
"OneflowSingleDriver", | |||||
"OneflowDriver", | |||||
"oneflow_seed_everything", | |||||
"optimizer_state_to_device" | |||||
] | |||||
from .ddp import OneflowDDPDriver | |||||
from .single_device import OneflowSingleDriver | |||||
from .oneflow_driver import OneflowDriver | |||||
from .utils import oneflow_seed_everything, optimizer_state_to_device | |||||
@@ -0,0 +1,323 @@ | |||||
import os | |||||
from typing import List, Optional, Union, Dict | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
import oneflow.comm as comm | |||||
import oneflow.env as dist_env | |||||
from oneflow.nn.parallel import DistributedDataParallel | |||||
from oneflow.utils.data import BatchSampler | |||||
__all__ = [ | |||||
"OneflowDDPDriver" | |||||
] | |||||
from .oneflow_driver import OneflowDriver | |||||
from fastNLP.core.drivers.oneflow_driver.utils import ( | |||||
replace_sampler, | |||||
replace_batch_sampler | |||||
) | |||||
from fastNLP.core.utils import check_user_specific_params | |||||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, \ | |||||
ReproducibleBatchSampler, \ | |||||
re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler | |||||
from fastNLP.envs import FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC | |||||
from fastNLP.core.log import logger | |||||
from fastNLP.core.drivers.oneflow_driver.dist_utils import fastnlp_oneflow_all_gather, fastnlp_oneflow_broadcast_object | |||||
from .utils import _check_dataloader_args_for_distributed | |||||
class OneflowDDPDriver(OneflowDriver): | |||||
r""" | |||||
``OneflowDDPDriver`` 实现了动态图下使用 ``DistributedDataParallel`` 进行的数据并行分布式训练。 | |||||
.. note:: | |||||
您在绝大多数情况下不需要自己使用到该类,通过向 ``Trainer`` 传入正确的参数,您可以方便快速地部署您的分布式训练; | |||||
``OneflowDDPDriver`` 目前支持两种启动方式: | |||||
1. 用户不做任何处理,通过运行 ``python -m oneflow.distributed.launch --nproc_per_node 2 train.py`` 启动; | |||||
2. 用户将模型通过 ``DistributedDataParallel`` 处理后,通过运行 ``python -m oneflow.distributed.launch --nproc_per_node 2 train.py`` 启动; | |||||
注意多机的启动强制要求用户在每一台机器上使用 ``python -m oneflow.distributed.launch`` 启动;因此我们不会在 ``OneflowDDPDriver`` 中保存 | |||||
任何当前有多少台机器的信息; | |||||
:param model: 传入给 ``Trainer`` 的 ``model`` 参数; | |||||
:param parallel_device: 该参数无效,**fastNLP** 会自动获取当前进程的设备; | |||||
:param fp16: 是否开启 fp16 训练;目前该参数无效; | |||||
:param oneflow_kwargs: | |||||
* *ddp_kwargs* -- 用于 ``DistributedDataParallel`` 的其它参数,详情可查阅 **oneflow** 的官方文档; | |||||
""" | |||||
def __init__( | |||||
self, | |||||
model, | |||||
parallel_device: Optional["oneflow.device"], | |||||
fp16: bool = False, | |||||
oneflow_kwargs: Dict = None, | |||||
**kwargs | |||||
): | |||||
super(OneflowDDPDriver, self).__init__(model, fp16=fp16, oneflow_kwargs=oneflow_kwargs, **kwargs) | |||||
# oneflow 会自己初始化通信组,因此 parallel_device 实际上不起作用,可以通过 current_device 获取设备 | |||||
self.model_device = oneflow.device("cuda", oneflow.cuda.current_device()) | |||||
self._data_device = self.model_device | |||||
self.global_rank = int(os.environ["RANK"]) | |||||
self.world_size = int(os.environ["WORLD_SIZE"]) | |||||
self._ddp_kwargs = self._oneflow_kwargs.get("ddp_kwargs", {}) | |||||
check_user_specific_params(self._ddp_kwargs, DistributedDataParallel.__init__, DistributedDataParallel.__name__) | |||||
if len(self.model._buffers) != 0 and self._ddp_kwargs.get("broadcast_buffers", None) is None: | |||||
logger.info("Notice your model has buffers and you are using `OneflowDDPDriver`, but you do not set " | |||||
"'broadcast_buffers' in your trainer. Cause in most situations, this parameter can be set" | |||||
" to 'False' to avoid redundant data communication between different processes.") | |||||
self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error") | |||||
assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type." | |||||
if self.output_from_new_proc not in {"all", "ignore", "only_error"}: | |||||
os.makedirs(name=self.output_from_new_proc, exist_ok=True) | |||||
self.output_from_new_proc = os.path.abspath(self.output_from_new_proc) | |||||
self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; | |||||
self._has_ddpwrapped = False# hasattr(model, ) | |||||
def setup(self): | |||||
r""" | |||||
将模型用 ``DistributedDataParallel`` 进行处理; | |||||
""" | |||||
if self._has_setup: | |||||
return | |||||
self._has_setup = True | |||||
self.configure_ddp() | |||||
self.barrier() | |||||
# 初始化 self._pids,从而使得每一个进程都能接受到 rank0 的 send 操作; | |||||
# self._pids = [oneflow.tensor(0, dtype=oneflow.int).to(self.data_device) for _ in range(dist_env.get_world_size())] | |||||
# comm.all_gather(self._pids, oneflow.tensor(os.getpid(), dtype=oneflow.int).to(self.data_device)) | |||||
# local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) if "LOCAL_WORLD_SIZE" in os.environ else None | |||||
# if local_world_size is None: | |||||
# local_world_size = oneflow.tensor(int(os.environ.get("LOCAL_RANK")), dtype=oneflow.int).to(self.data_device) | |||||
# comm.all_reduce(local_world_size, op=dist_env.ReduceOp.MAX) | |||||
# local_world_size = local_world_size.tolist() + 1 | |||||
# node_rank = self.global_rank // local_world_size | |||||
# self._pids = self._pids[node_rank * local_world_size: (node_rank + 1) * local_world_size] | |||||
# self._pids = self.tensor_to_numeric(self._pids) | |||||
def configure_ddp(self): | |||||
if not hasattr(self.model, "_ddp_state_for_reversed_params"): | |||||
self.model.to(self.model_device) | |||||
self.model = DistributedDataParallel( | |||||
# 注意这里的 self.model_device 是 `oneflow.device` type,因此 self.model_device.index; | |||||
self.model, | |||||
**self._ddp_kwargs | |||||
) | |||||
self._has_ddpwrapped = True | |||||
@property | |||||
def master_address(self) -> str: | |||||
return os.environ.get("MASTER_ADDR") | |||||
@property | |||||
def master_port(self) -> str: | |||||
return os.environ.get("MASTER_PORT") | |||||
@property | |||||
def world_size(self) -> int: | |||||
return self._world_size | |||||
@world_size.setter | |||||
def world_size(self, size: int): | |||||
self._world_size = size | |||||
@property | |||||
def global_rank(self) -> int: | |||||
return self._global_rank | |||||
@global_rank.setter | |||||
def global_rank(self, rank: int) -> None: | |||||
self._global_rank = rank | |||||
@property | |||||
def local_rank(self) -> int: # 这个不会受到 all_rank_call_context 的影响 | |||||
return int(os.environ.get("LOCAL_RANK", 0)) | |||||
@property | |||||
def data_device(self): | |||||
return self._data_device | |||||
def set_dist_repro_dataloader(self, dataloader, | |||||
dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]] = None, | |||||
reproducible: bool = False): | |||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load_checkpoint 函数调用; | |||||
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 OneflowDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; | |||||
if isinstance(dist, ReproducibleBatchSampler): | |||||
dist.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_batch_sampler(dataloader, dist) | |||||
if isinstance(dist, ReproducibleSampler): | |||||
dist.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_sampler(dataloader, dist) | |||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | |||||
# trainer, evaluator | |||||
if dist is None: | |||||
if reproducible: | |||||
raise RuntimeError("It is not allowed to save checkpoint if the sampler is not allowed to be replaced.") | |||||
else: | |||||
args = self.get_dataloader_args(dataloader) | |||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||||
return replace_batch_sampler(dataloader, re_instantiate_sampler(args.batch_sampler)) | |||||
if isinstance(args.sampler, ReproducibleSampler): | |||||
return replace_sampler(dataloader, re_instantiate_sampler(args.sampler)) | |||||
return dataloader | |||||
# trainer | |||||
elif dist == "dist": | |||||
args = self.get_dataloader_args(dataloader) | |||||
# 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | |||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||||
batch_sampler.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
elif isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
sampler.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_sampler(dataloader, sampler) | |||||
else: | |||||
_check_dataloader_args_for_distributed(args, controller="Trainer") | |||||
sampler = RandomSampler( | |||||
dataset=args.dataset, | |||||
shuffle=args.shuffle, | |||||
seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0)) | |||||
) | |||||
sampler.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_sampler(dataloader, sampler) | |||||
# evaluator | |||||
elif dist == "unrepeatdist": | |||||
args = self.get_dataloader_args(dataloader) | |||||
if isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) | |||||
elif not isinstance(args.sampler, UnrepeatedSampler): | |||||
_check_dataloader_args_for_distributed(args, controller="Evaluator") | |||||
sampler = UnrepeatedSequentialSampler( | |||||
dataset=args.dataset | |||||
) | |||||
else: | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
sampler.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank | |||||
) | |||||
batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
else: | |||||
raise ValueError( | |||||
"Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | |||||
def is_global_zero(self): | |||||
r""" | |||||
:return: 返回当前的进程是否在全局上是进程 0 ; | |||||
""" | |||||
return self.global_rank == 0 | |||||
def get_model_no_sync_context(self): | |||||
r""" | |||||
:return: 返回一个 ``context`` 上下文环境,用于关闭各个进程之间的同步;该功能暂时无效,返回一个空的上下文环境; | |||||
""" | |||||
# TODO 暂时没有在 oneflow 中找到类似的功能; | |||||
from fastNLP.core.utils import nullcontext | |||||
return nullcontext | |||||
return self.model.no_sync | |||||
def unwrap_model(self): | |||||
r""" | |||||
:return: 返回原始模型; | |||||
""" | |||||
return self.model | |||||
def get_local_rank(self) -> int: | |||||
r""" | |||||
:return: 返回当前进程局部的进程编号; | |||||
""" | |||||
return self.local_rank | |||||
def barrier(self): | |||||
r""" | |||||
通过使用该函数来使得各个进程之间同步操作; | |||||
""" | |||||
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 | |||||
comm.barrier() | |||||
def is_distributed(self): | |||||
r""" | |||||
:return: 返回当前使用的 driver 是否是分布式的 driver,对于 ``OneflowDDPDriver`` 来说,该函数一定返回 ``True``; | |||||
""" | |||||
return True | |||||
def broadcast_object(self, obj, src: int = 0, **kwargs): | |||||
r""" | |||||
从 src 端将 obj 对象(可能是 tensor ,可能是 object )发送到 dst 处。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行 | |||||
传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。 | |||||
:param obj: obj,可能是 Tensor 或 嵌套类型的数据 | |||||
:param int src: source 的 global rank 。 | |||||
:param int dst: target 的 global rank,可以是多个目标 rank | |||||
:param group: 所属的 group | |||||
:return: 如果当前不是分布式 driver 直接返回输入的 obj 。如果当前 rank 是接收端(其 global rank 包含在了 dst 中),则返回 | |||||
接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 | |||||
""" | |||||
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC == 2 直接返回。 | |||||
return | |||||
return fastnlp_oneflow_broadcast_object(obj, src, device=self.data_device) | |||||
def all_gather(self, obj) -> List: | |||||
r""" | |||||
将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过 | |||||
pickle 进行序列化,接收到之后再反序列化。 | |||||
example:: | |||||
obj = { | |||||
'a': [1, 1], | |||||
'b': [[1, 2], [1, 2]], | |||||
'c': { | |||||
'd': [1, 2] | |||||
} | |||||
} | |||||
-> | |||||
[ | |||||
{'a': 1, 'b':[1, 2], 'c':{'d': 1}}, | |||||
{'a': 1, 'b':[1, 2], 'c':{'d': 2}} | |||||
] | |||||
:param obj: 需要传输的对象,在每个rank上都应该保持相同的结构。 | |||||
:param group: | |||||
:return: | |||||
""" | |||||
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC 表示不执行 | |||||
return [obj] | |||||
return fastnlp_oneflow_all_gather(obj) |
@@ -0,0 +1,306 @@ | |||||
import io | |||||
import pickle | |||||
import os | |||||
from typing import Any, List | |||||
from fastNLP.core.utils import apply_to_collection, get_oneflow_device | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
from fastNLP.envs.env import FASTNLP_NO_SYNC | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
import oneflow.comm as comm | |||||
import oneflow.env as dist_env | |||||
PROTOCOL_VERSION = 1 | |||||
def _validate_output_list_for_rank(my_rank, dst, gather_list): | |||||
if dst == my_rank: | |||||
if not gather_list: | |||||
raise ValueError( | |||||
"Argument ``gather_list`` must be specified on destination rank." | |||||
) | |||||
elif gather_list: | |||||
raise ValueError( | |||||
"Argument ``gather_list`` must NOT be specified " | |||||
"on non-destination ranks." | |||||
) | |||||
obj = {"protocol_version": PROTOCOL_VERSION, "data": obj} | |||||
pickled_bytes = pickle.dumps(obj) | |||||
def fastnlp_oneflow_gather_object(obj, dst=0): | |||||
""" | |||||
从其它 rank gather 东西到 dst rank 。 | |||||
Example:: | |||||
>>> # Assumes world_size of 3. | |||||
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object | |||||
>>> output = [None for _ in gather_objects] | |||||
>>> fastnlp_oneflow_gather_object( | |||||
gather_objects[dist.get_rank()], | |||||
output if dist.get_rank() == 0 else None, | |||||
dst=0 | |||||
) | |||||
>>> # On rank 0 | |||||
>>> output | |||||
['foo', 12, {1: 2}] | |||||
:param obj: 需要发送的 obj 对象,需要是可以 pickable 的对象 | |||||
:param dst: 目标的 rank 。 | |||||
:return: 在 dst 上面返回 world_size 的 list,依次为 rank 0;rank 1...上 obj | |||||
""" | |||||
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: | |||||
return [obj] | |||||
if dist_env.get_rank() == dst: | |||||
object_gather_list = [None for _ in range(dist_env.get_world_size())] | |||||
else: | |||||
object_gather_list = None | |||||
# Ensure object_gather_list is specified appopriately. | |||||
my_rank = dist_env.get_rank() | |||||
_validate_output_list_for_rank(my_rank, dst, object_gather_list) | |||||
# 防止 unpickle 的时候出现在了发送的 gpu 上。 | |||||
obj = apply_to_collection(obj, oneflow.Tensor, _to_device, device=oneflow.device("cpu")) | |||||
input_tensor, local_size = _object_to_tensor(obj) | |||||
current_device = oneflow.device("cuda") | |||||
input_tensor = input_tensor.to(current_device) | |||||
local_size = local_size.to(current_device) | |||||
# Gather all local sizes. This is so that we can find the max size, and index | |||||
# until the correct size when deserializing the tensors. | |||||
group_size = dist_env.get_world_size() | |||||
object_sizes_tensor = oneflow.zeros(group_size, dtype=oneflow.long, device=current_device) | |||||
object_size_list = [ | |||||
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) | |||||
] | |||||
# Allgather tensor sizes. An all-gather is needed here despite this being a | |||||
# gather, since each rank needs to broadcast a tensor of the same (maximal) | |||||
# size. | |||||
comm.all_gather(object_size_list, local_size) | |||||
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] | |||||
# Resize tensor to max size across all ranks. | |||||
input_tensor = input_tensor.reshape(max_object_size) | |||||
# Avoid populating output tensors if the result won't be gathered on this rank. | |||||
if my_rank == dst: | |||||
coalesced_output_tensor = oneflow.empty( | |||||
max_object_size * group_size, dtype=oneflow.uint8, device=current_device | |||||
) | |||||
# Output tensors are nonoverlapping views of coalesced_output_tensor | |||||
output_tensors = [ | |||||
coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] | |||||
for i in range(group_size) | |||||
] | |||||
# All ranks call gather with equal-sized tensors. | |||||
comm.gather( | |||||
input_tensor, | |||||
gather_list=output_tensors if my_rank == dst else None, | |||||
dst=dst, | |||||
) | |||||
if my_rank != dst: | |||||
return | |||||
for i, tensor in enumerate(output_tensors): | |||||
tensor = tensor.type(oneflow.uint8) # type: ignore[call-overload] | |||||
tensor_size = object_size_list[i] | |||||
object_gather_list[i] = _tensor_to_object(tensor, tensor_size) | |||||
def _object_to_tensor(obj, device=None): | |||||
f = io.BytesIO() | |||||
obj = {"protocol_version": PROTOCOL_VERSION, "data": obj} | |||||
pickled_bytes = pickle.dumps(obj) | |||||
byte_tensor = oneflow.ByteTensor(list(pickled_bytes)) | |||||
local_size = oneflow.LongTensor([byte_tensor.numel()]) | |||||
if device is not None: | |||||
byte_tensor = byte_tensor.to(device) | |||||
local_size = local_size.to(device) | |||||
return byte_tensor, local_size | |||||
def _tensor_to_object(tensor, tensor_size): | |||||
buf = tensor.detach().cpu().numpy().tobytes()[:tensor_size] | |||||
res = pickle.loads(buf) | |||||
assert res["protocol_version"] == PROTOCOL_VERSION | |||||
return res["data"] | |||||
def send_recv_object(obj, src, cur_rank, device): | |||||
r""" | |||||
oneflow 中的单点对多点的分发函数; | |||||
例如将进程 0 上的对象 object 分发到其它进程上; | |||||
Example:: | |||||
cur_rank = int(os.environ.get('LOCAL_RANK', 0)) | |||||
# 拿到 local_device | |||||
send_recv_object(object, 0, cur_rank, local_device) | |||||
:param obj: 一个可以序列化的 python 对象; | |||||
:param src: 从哪一个 rank 上发送到其它 rank; | |||||
:param cur_rank: 当前的进程的 rank 序号; | |||||
:param device: 当前的进程所在的设备; | |||||
:param group: 通信组,默认为 None; | |||||
:param tag: 将发送与远程接收匹配的标记; | |||||
:return: | |||||
""" | |||||
# src rank send to all other ranks | |||||
size = oneflow.LongTensor([0]).to(device) | |||||
if cur_rank == src: | |||||
world_size = dist_env.get_world_size() | |||||
tensor, size = _object_to_tensor(obj) | |||||
tensor = tensor.to(device) | |||||
size = size.to(device) | |||||
# 首先同步 obj 的 size 的信息; | |||||
comm.broadcast(size, src) | |||||
for subrank in range(world_size): | |||||
if subrank != src: | |||||
comm.send(tensor=tensor, dst=subrank) | |||||
else: | |||||
comm.broadcast(size, src) | |||||
tensor = oneflow.ByteTensor([0] * size).to(device) | |||||
comm.recv(tensor=tensor, src=src) | |||||
return _tensor_to_object(tensor.cpu(), size) | |||||
def _to_device(tensor, device): | |||||
return tensor.contiguous().to(device) | |||||
def fastnlp_oneflow_all_gather(obj: Any, device=None) ->List: | |||||
""" | |||||
实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 | |||||
example:: | |||||
obj = { | |||||
'a': [1, 1], | |||||
'b': [[1, 2], [1, 2]], | |||||
'c': { | |||||
'd': [1, 2] | |||||
} | |||||
} | |||||
-> | |||||
[ | |||||
{'a': 1, 'b':[1, 2], 'c':{'d': 1}}, | |||||
{'a': 1, 'b':[1, 2], 'c':{'d': 2}} | |||||
] | |||||
:param obj: 任意结构的数据,如果为 tensor ,需要保证每个显卡上的 tensor 的形状是一样的。如果传入的是非 tensor 对象都将直接进行 | |||||
序列化之后进行传输。 | |||||
:param device: 当前该参数无意义。 | |||||
:param group: | |||||
:return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 | |||||
""" | |||||
if int(os.environ.get(FASTNLP_NO_SYNC, "0")) == 2: | |||||
return [obj] | |||||
if isinstance(obj, oneflow.Tensor): | |||||
objs = [oneflow.zeros_like(obj) for _ in range(dist_env.get_world_size())] | |||||
comm.all_gather(objs, obj) | |||||
else: | |||||
objs = [None for _ in range(dist_env.get_world_size())] | |||||
# 防止 unpickle 的时候弄到发送的 gpu 上了 | |||||
obj = apply_to_collection(obj, oneflow.Tensor, _to_device, device=oneflow.device("cpu")) | |||||
all_gather_object(objs, obj) | |||||
return objs | |||||
def fastnlp_oneflow_broadcast_object(obj, src, device=None): | |||||
""" | |||||
将 src 上的 obj 对象广播到其它 rank 上。 | |||||
:param obj: 需要发送的对象 | |||||
:param src: 从哪里发出。 | |||||
:param device: | |||||
:param group: 属于哪个通信 group | |||||
:return: | |||||
""" | |||||
if int(os.environ.get(FASTNLP_NO_SYNC, "0")) == 2: | |||||
if src == dist_env.get_rank(): | |||||
return obj | |||||
else: | |||||
return None | |||||
cur_rank = dist_env.get_rank() | |||||
if cur_rank == src: | |||||
# 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里 | |||||
obj = apply_to_collection(obj, oneflow.Tensor, _to_device, device=oneflow.device("cpu")) | |||||
if device is None: | |||||
device = oneflow.cuda.current_device() | |||||
device = get_oneflow_device(device) | |||||
if cur_rank == src: | |||||
tensor, size = _object_to_tensor(obj, device=device) | |||||
else: | |||||
size = oneflow.LongTensor([0]).to(device) | |||||
comm.broadcast(size, src=src) | |||||
if cur_rank != src: | |||||
tensor = oneflow.empty( | |||||
size.int().item(), # type: ignore[arg-type] | |||||
dtype=oneflow.uint8, | |||||
device=device | |||||
) | |||||
comm.broadcast(tensor, src=src) | |||||
return _tensor_to_object(tensor, tensor_size=size.item()) | |||||
def all_gather_object(object_list, obj): | |||||
""" | |||||
Example:: | |||||
>>> # Note: Process group initialization omitted on each rank. | |||||
>>> # Assumes world_size of 3. | |||||
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object | |||||
>>> output = [None for _ in gather_objects] | |||||
>>> all_gather_object(output, gather_objects[dist.get_rank()]) | |||||
>>> output | |||||
['foo', 12, {1: 2}] | |||||
:param object_list: | |||||
:param obj: | |||||
:param group: | |||||
:return: | |||||
""" | |||||
if int(os.environ.get(FASTNLP_NO_SYNC, "0")) == 2: | |||||
return [obj] | |||||
current_device = get_oneflow_device(oneflow.cuda.current_device()) | |||||
input_tensor, local_size = _object_to_tensor(obj, device=current_device) | |||||
# Gather all local sizes. This is so that we can find the max size, and index | |||||
# until the correct size when deserializing the tensors. | |||||
group_size = dist_env.get_world_size() | |||||
object_sizes_tensor = oneflow.zeros( | |||||
group_size, dtype=oneflow.long, device=current_device | |||||
) | |||||
object_size_list = [ | |||||
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) | |||||
] | |||||
# Allgather tensor sizes | |||||
comm.all_gather(object_size_list, local_size) | |||||
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] | |||||
# Resize tensor to max size across all ranks. | |||||
input_tensor = input_tensor.reshape(max_object_size) | |||||
coalesced_output_tensor = oneflow.empty( | |||||
max_object_size * group_size, dtype=oneflow.uint8, device=current_device | |||||
) | |||||
# Output tensors are nonoverlapping views of coalesced_output_tensor | |||||
output_tensors = [ | |||||
coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] | |||||
for i in range(group_size) | |||||
] | |||||
comm.all_gather(output_tensors, input_tensor) | |||||
# Deserialize outputs back to object. | |||||
for i, tensor in enumerate(output_tensors): | |||||
tensor = tensor.type(oneflow.uint8) | |||||
if tensor.device != oneflow.device("cpu"): | |||||
tensor = tensor.cpu() | |||||
tensor_size = object_size_list[i] | |||||
object_list[i] = _tensor_to_object(tensor, tensor_size) | |||||
return object_list |
@@ -0,0 +1,70 @@ | |||||
import os | |||||
from typing import Optional, Union, List, Sequence | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
from .oneflow_driver import OneflowDriver | |||||
from .single_device import OneflowSingleDriver | |||||
from .ddp import OneflowDDPDriver | |||||
from fastNLP.core.log import logger | |||||
from fastNLP.envs import FASTNLP_BACKEND_LAUNCH | |||||
__all__ = [] | |||||
def initialize_oneflow_driver(driver: str, device: Optional[Union[str, "oneflow.device", int, List[int]]], | |||||
model: "oneflow.nn.Module", **kwargs) -> OneflowDriver: | |||||
r""" | |||||
用来根据参数 ``driver` 和 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去; | |||||
:param driver: 该参数的值应为以下之一:``["oneflow"]``; | |||||
:param device: 该参数的格式与 ``Trainer`` 对参数 ``device`` 的要求一致; | |||||
:param model: 训练或者评测的具体的模型; | |||||
:return: 返回一个 :class:`~fastNLP.core.OneflowSingleDriver` 或 :class:`~fastNLP.core.OneflowDDPDriver` 实例; | |||||
""" | |||||
# world_size 和 rank | |||||
if FASTNLP_BACKEND_LAUNCH in os.environ: | |||||
if device is not None: | |||||
logger.rank_zero_warning("Parameter `device` would be ignored when you are using `oneflow.distributed.launch` to pull " | |||||
"up your script. ", once=True) | |||||
return OneflowDDPDriver(model, None, **kwargs) | |||||
if driver not in {"oneflow"}: | |||||
raise ValueError("Parameter `driver` can only be one of these values: ['oneflow'].") | |||||
_could_use_device_num = oneflow.cuda.device_count() | |||||
if isinstance(device, str): | |||||
device = oneflow.device(device) | |||||
elif isinstance(device, int): | |||||
if device < 0: | |||||
if device != -1: | |||||
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | |||||
device = [oneflow.device(f"cuda:{w}") for w in range(_could_use_device_num)] | |||||
elif device >= _could_use_device_num: | |||||
print(device, _could_use_device_num) | |||||
raise ValueError("The gpu device that parameter `device` specifies is not existed.") | |||||
else: | |||||
device = oneflow.device(f"cuda:{device}") | |||||
elif isinstance(device, Sequence): | |||||
device = list(set(device)) | |||||
for each in device: | |||||
if not isinstance(each, int): | |||||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be 'int' type.") | |||||
elif each < 0: | |||||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be bigger than 0.") | |||||
elif each >= _could_use_device_num: | |||||
raise ValueError(f"When parameter `device` is 'Sequence' type, the value in it should not be bigger than" | |||||
f" the available gpu number:{_could_use_device_num}.") | |||||
device = [oneflow.device(f"cuda:{w}") for w in device] | |||||
elif device is not None and not isinstance(device, oneflow.device): | |||||
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | |||||
if driver == "oneflow": # single, ddp, 直接启动。 | |||||
if not isinstance(device, List): | |||||
return OneflowSingleDriver(model, device, **kwargs) | |||||
else: | |||||
raise RuntimeError("If you want to run distributed training, please use " | |||||
"'python -m oneflow.distributed.launch xxx.py'.") | |||||
return OneflowDDPDriver(model, device, **kwargs) |
@@ -0,0 +1,445 @@ | |||||
import os | |||||
from typing import Union, Dict, Optional, Callable, Tuple | |||||
from functools import partial | |||||
import numpy as np | |||||
import random | |||||
from dataclasses import dataclass | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
from pathlib import Path | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
from oneflow.utils.data import DataLoader, Sampler, BatchSampler, Dataset | |||||
from oneflow.optim import Optimizer | |||||
from oneflow.utils.data import RandomSampler as OneflowRandomSampler | |||||
_reduces = { | |||||
"sum": oneflow.sum, | |||||
"min": oneflow.min, | |||||
"max": oneflow.max, | |||||
"mean": oneflow.mean | |||||
} | |||||
__all__ = [ | |||||
"OneflowDriver" | |||||
] | |||||
from .utils import optimizer_state_to_device, DummyGradScaler | |||||
from fastNLP.core.drivers.driver import Driver | |||||
from fastNLP.core.utils.utils import _get_fun_msg, nullcontext | |||||
from fastNLP.core.utils import apply_to_collection, oneflow_move_data_to_device, auto_param_call | |||||
from fastNLP.envs import rank_zero_call | |||||
from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | |||||
from fastNLP.core.log import logger | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler | |||||
from fastNLP.core.dataloaders import OverfitDataLoader | |||||
class OneflowDriver(Driver): | |||||
r""" | |||||
专属于 ``oneflow`` 的 ``driver``,是 ``OneflowSingleDriver`` 和 ``OneflowDDPDriver`` 的父类; | |||||
.. warning:: | |||||
您不应当直接初始化该类,然后传入给 ``Trainer``,换句话说,您应当使用该类的子类 ``OneflowSingleDriver`` 和 ``OneflowDDPDriver``,而不是 | |||||
该类本身; | |||||
.. note:: | |||||
您可以在使用 ``OneflowSingleDriver`` 和 ``OneflowDDPDriver`` 时使用 ``OneflowDriver`` 提供的接口; | |||||
""" | |||||
def __init__(self, model, fp16: Optional[bool] = False, oneflow_kwargs: Dict = None, **kwargs): | |||||
super(OneflowDriver, self).__init__(model) | |||||
""" 进行 fp16 的设置 """ | |||||
self._oneflow_kwargs = oneflow_kwargs if oneflow_kwargs is not None else {} | |||||
self.fp16 = fp16 | |||||
if fp16: | |||||
logger.warn("OneflowDriver of eager mode dose not support fp16 now.``") | |||||
# self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not self.fp16) | |||||
# self.grad_scaler = _grad_scaler(**self._oneflow_kwargs.get("gradscaler_kwargs", {})) | |||||
self.auto_cast = nullcontext | |||||
self.grad_scaler = DummyGradScaler() | |||||
self.set_grad_to_none = self._oneflow_kwargs.get("set_grad_to_none") | |||||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | |||||
def zero_grad(self): | |||||
for optimizer in self.optimizers: | |||||
optimizer.zero_grad(self.set_grad_to_none) | |||||
def backward(self, loss): | |||||
loss.backward() | |||||
# self.grad_scaler.scale(loss).backward() | |||||
def step(self): | |||||
for optimizer in self.optimizers: | |||||
self.grad_scaler.step(optimizer) | |||||
self.grad_scaler.update() | |||||
def check_dataloader_legality(self, dataloader): | |||||
if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader): | |||||
raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") | |||||
if len(dataloader) == 0: | |||||
logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it " | |||||
"may cause some unexpected exceptions.", once=True) | |||||
@staticmethod | |||||
def _check_optimizer_legality(optimizers): | |||||
for each_optimizer in optimizers: | |||||
if not isinstance(each_optimizer, Optimizer): | |||||
raise TypeError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, " | |||||
f"not {type(each_optimizer)}.") | |||||
@staticmethod | |||||
def tensor_to_numeric(tensor, reduce: str = None): | |||||
r""" | |||||
将 ``oneflow.Tensor`` 转换成 python 中的数值类型; | |||||
:param tensor: ``oneflow.Tensor``; | |||||
:param reduce: 当 tensor 是一个多数值的张量时,应当使用何种归一化操作来转换成单一数值,应当为以下类型之一:``['max', 'min', 'sum', 'mean']``; | |||||
:return: 返回一个单一数值,其数值类型是 python 中的基本的数值类型,例如 ``int,float`` 等; | |||||
""" | |||||
if tensor is None: | |||||
return None | |||||
def _translate(_data): | |||||
if _data.numel() == 1: | |||||
return _data.item() | |||||
if reduce is None: | |||||
return _data.tolist() | |||||
return _reduces[reduce](_data).item() | |||||
return apply_to_collection( | |||||
data=tensor, | |||||
dtype=oneflow.Tensor, | |||||
function=_translate | |||||
) | |||||
def set_model_mode(self, mode: str): | |||||
r""" | |||||
设置模型的状态是 ``train`` 还是 ``eval``; | |||||
:param mode: ``'train'`` 或 ``'eval'``; | |||||
""" | |||||
assert mode in {"train", "eval"} | |||||
getattr(self.model, mode)() | |||||
@rank_zero_call | |||||
def save_model(self, filepath: Union[str, Path], only_state_dict: bool = True, **kwargs): | |||||
""" | |||||
保存当前 driver 的模型到 folder 下。 | |||||
:param filepath: 保存到哪个文件夹; | |||||
:param only_state_dict: 是否只保存权重;如果使用 ``DistributedDataParallel`` 启动分布式训练的话,该参数只能为 ``True``; | |||||
:return: | |||||
""" | |||||
model = self.unwrap_model() | |||||
if not only_state_dict and self.is_distributed(): | |||||
logger.warn("`Cannot save ddp model directly, we will save its state_dict for you.") | |||||
only_state_dict = True | |||||
if only_state_dict: | |||||
states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} | |||||
oneflow.save(states, filepath) | |||||
else: | |||||
if self.model_device is not None: | |||||
if not self.is_distributed(): | |||||
self.move_model_to_device(model, oneflow.device("cpu")) | |||||
oneflow.save(model, filepath) | |||||
if not self.is_distributed(): | |||||
self.move_model_to_device(model, self.model_device) | |||||
else: | |||||
oneflow.save(model, filepath) | |||||
def load_model(self, filepath: Union[Path, str], only_state_dict: bool = True, **kwargs): | |||||
""" | |||||
从 folder 中加载权重并赋值到当前 driver 的模型上。 | |||||
:param filepath: 加载权重或模型的路径 | |||||
:param load_state_dict: 保存的内容是否只是权重。 | |||||
:param kwargs: | |||||
:return: | |||||
""" | |||||
model = self.unwrap_model() | |||||
res = oneflow.load(filepath) | |||||
if isinstance(res, dict) and only_state_dict is False: | |||||
logger.rank_zero_warning(f"It seems like that {filepath} only contains state, you may need to use " | |||||
f"`only_state_dict=True`") | |||||
elif not isinstance(res, dict) and only_state_dict is True: | |||||
logger.rank_zero_warning(f"It seems like that {filepath} is not state, you may need to use " | |||||
f"`only_state_dict=False`") | |||||
if not isinstance(res, dict): | |||||
res = res.state_dict() | |||||
model.load_state_dict(res) | |||||
@rank_zero_call | |||||
def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||||
# 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变 | |||||
# trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; | |||||
# 1. sampler 的状态; | |||||
num_consumed_batches = states.pop("num_consumed_batches") | |||||
states["sampler_states"] = self.get_sampler_state(dataloader, num_consumed_batches) | |||||
# 2. 保存模型的状态; | |||||
if should_save_model: | |||||
if not os.path.exists(folder): | |||||
os.mkdir(folder) | |||||
model_path = folder.joinpath(FASTNLP_MODEL_FILENAME) | |||||
self.save_model(model_path, only_state_dict=only_state_dict) | |||||
# 3. 保存 optimizers 的状态; | |||||
states["optimizers_state_dict"] = self.get_optimizer_state() | |||||
logger.debug("Save optimizer state dict.") | |||||
# # 4. 保存fp16的状态 | |||||
# if not isinstance(self.grad_scaler, DummyGradScaler): | |||||
# grad_scaler_state_dict = self.grad_scaler.state_dict() | |||||
# states['grad_scaler_state_dict'] = grad_scaler_state_dict | |||||
oneflow.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | |||||
def get_sampler_state(self, dataloader, num_consumed_batches): | |||||
dataloader_args = self.get_dataloader_args(dataloader) | |||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||||
sampler = dataloader_args.batch_sampler | |||||
elif dataloader_args.sampler: | |||||
sampler = dataloader_args.sampler | |||||
else: | |||||
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | |||||
if hasattr(sampler, "state_dict") and callable(sampler.state_dict): | |||||
sampler_states = sampler.state_dict() | |||||
if dataloader_args.batch_size is not None: | |||||
sampler_states["num_consumed_samples"] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
* num_consumed_batches | |||||
else: | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on sampler's " | |||||
"`num_consumed_samples`, it may cause missing some samples when reload.") | |||||
else: | |||||
raise RuntimeError("The sampler has no `state_dict()` method, fastNLP cannot save the training " | |||||
"state.") | |||||
return sampler_states | |||||
def load_sampler_state(self, dataloader, sampler_states): | |||||
states = {} | |||||
dataloader_args = self.get_dataloader_args(dataloader) | |||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||||
sampler = dataloader_args.batch_sampler | |||||
elif isinstance(dataloader_args.sampler, ReproducibleSampler): | |||||
sampler = dataloader_args.sampler | |||||
elif isinstance(dataloader_args.sampler, OneflowRandomSampler): | |||||
sampler = RandomSampler(dataloader_args.sampler.data_source) | |||||
logger.debug("Replace oneflow RandomSampler into fastNLP RandomSampler.") | |||||
elif self.is_distributed(): | |||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our" | |||||
"`ReproducibleSampler`.") | |||||
else: | |||||
sampler = ReproduceBatchSampler( | |||||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | |||||
batch_size=dataloader_args.batch_size, | |||||
drop_last=dataloader_args.drop_last | |||||
) | |||||
sampler.load_state_dict(sampler_states) | |||||
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | |||||
# 修改 trainer_state.batch_idx_in_epoch | |||||
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | |||||
if not isinstance(sampler, ReproducibleBatchSampler): | |||||
if dataloader_args.drop_last: | |||||
batch_idx_in_epoch = len( | |||||
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size | |||||
else: | |||||
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ | |||||
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size | |||||
# sampler 是 batch_sampler; | |||||
else: | |||||
batch_idx_in_epoch = sampler.batch_idx_in_epoch | |||||
states["batch_idx_in_epoch"] = batch_idx_in_epoch | |||||
return states | |||||
def get_optimizer_state(self): | |||||
optimizers_state_dict = {} | |||||
for i in range(len(self.optimizers)): | |||||
optimizer: oneflow.optim.Optimizer = self.optimizers[i] | |||||
optimizer_state = optimizer.state_dict() | |||||
optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], oneflow.device("cpu")) | |||||
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | |||||
return optimizers_state_dict | |||||
def load_optimizer_state(self, states): | |||||
assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \ | |||||
f"checkpoint it is:{len(states)}" | |||||
for i in range(len(self.optimizers)): | |||||
optimizer: oneflow.optim.Optimizer = self.optimizers[i] | |||||
optimizer.load_state_dict(states[f"optimizer{i}"]) | |||||
logger.debug("Load optimizer state dict.") | |||||
def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||||
states = oneflow.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) | |||||
# 1. 加载 optimizers 的状态; | |||||
optimizers_state_dict = states.pop("optimizers_state_dict") | |||||
self.load_optimizer_state(optimizers_state_dict) | |||||
# 2. 加载模型状态; | |||||
if should_load_model: | |||||
self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict) | |||||
# # 3. 加载 fp16 的状态 | |||||
# if "grad_scaler_state_dict" in states: | |||||
# grad_scaler_state_dict = states.pop("grad_scaler_state_dict") | |||||
# if not isinstance(self.grad_scaler, DummyGradScaler): | |||||
# self.grad_scaler.load_state_dict(grad_scaler_state_dict) | |||||
# logger.debug("Load grad_scaler state dict...") | |||||
# elif not isinstance(self.grad_scaler, DummyGradScaler): | |||||
# logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, " | |||||
# f"the training process may be unstable.") | |||||
# 4. 恢复 sampler 的状态; | |||||
sampler_states = states.pop("sampler_states") | |||||
states_ret = self.load_sampler_state(dataloader, sampler_states) | |||||
states.update(states_ret) | |||||
return states | |||||
def get_evaluate_context(self): | |||||
r""" | |||||
:return: 返回 ``oneflow.no_grad`` 这个 context; | |||||
""" | |||||
return oneflow.no_grad | |||||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
return auto_param_call(fn, batch, signature_fn=signature_fn) | |||||
else: | |||||
return fn(batch) | |||||
def get_model_call_fn(self, fn: str) -> Tuple: | |||||
if hasattr(self.model, fn): | |||||
fn = getattr(self.model, fn) | |||||
if not callable(fn): | |||||
raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") | |||||
logger.debug(f"Use {_get_fun_msg(fn, with_fp=False)}...") | |||||
return fn, None | |||||
elif fn in {"train_step", "evaluate_step"}: | |||||
logger.debug(f"Use {_get_fun_msg(self.model.forward, with_fp=False)}...") | |||||
return self.model, self.model.forward | |||||
else: | |||||
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") | |||||
@staticmethod | |||||
def move_model_to_device(model: "oneflow.nn.Module", device: "oneflow.device"): | |||||
r""" | |||||
将模型迁移到对应的设备上; | |||||
""" | |||||
if device is not None: | |||||
model.to(device) | |||||
def move_data_to_device(self, batch): | |||||
""" | |||||
将一个 batch 的数据迁移到对应的设备上; | |||||
:param batch: 一个 batch 的数据,可以是 ``list、dict`` 等; | |||||
:return: | |||||
""" | |||||
return oneflow_move_data_to_device(batch, self.data_device) | |||||
@staticmethod | |||||
def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover | |||||
global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) | |||||
process_seed = oneflow.initial_seed() | |||||
base_seed = process_seed - worker_id | |||||
ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) | |||||
np.random.seed(ss.generate_state(4)) | |||||
oneflow_ss, stdlib_ss = ss.spawn(2) | |||||
oneflow.manual_seed(oneflow_ss.generate_state(1, dtype=np.uint64)[0]) | |||||
stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum() | |||||
random.seed(stdlib_seed) | |||||
def set_deterministic_dataloader(self, dataloader: "DataLoader"): | |||||
if dataloader.worker_init_fn is None: | |||||
dataloader.worker_init_fn = partial(self.worker_init_function, | |||||
rank=int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))) | |||||
def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx: int): | |||||
# 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; | |||||
if callable(getattr(dataloader.sampler, "set_epoch", None)): | |||||
dataloader.sampler.set_epoch(cur_epoch_idx) | |||||
@staticmethod | |||||
def get_dataloader_args(dataloader: "DataLoader"): | |||||
""" | |||||
获取 dataloader 的 shuffle 和 drop_last 属性; | |||||
""" | |||||
@dataclass | |||||
class Res: | |||||
dataset: Optional[Dataset] = None | |||||
batch_sampler: Optional[BatchSampler] = None | |||||
sampler: Optional[Sampler] = None | |||||
batch_size: Optional[int] = None | |||||
shuffle: Optional[bool] = None | |||||
drop_last: Optional[bool] = None | |||||
res = Res() | |||||
# oneflow 的 DataLoader 一定会有 dataset 属性; | |||||
res.dataset = dataloader.dataset | |||||
# dataloader 使用的是 sampler; | |||||
if dataloader.batch_sampler is None: | |||||
res.sampler = dataloader.sampler | |||||
res.batch_size = 1 | |||||
res.shuffle = True if isinstance(dataloader.sampler, RandomSampler) else False | |||||
res.drop_last = False | |||||
# dataloader 使用的是 batch_sampler; | |||||
else: | |||||
res.batch_sampler = dataloader.batch_sampler | |||||
if hasattr(dataloader.batch_sampler, "batch_size"): | |||||
res.batch_size = getattr(dataloader.batch_sampler, "batch_size") | |||||
# 用户使用的是自己的 batch_sampler 并且其没有 "batch_size" 属性; | |||||
else: | |||||
dataloader_iter = iter(dataloader) | |||||
pre_sample = next(dataloader_iter) | |||||
res.batch_size = pre_sample.shape[0] | |||||
if hasattr(dataloader.batch_sampler, "sampler"): | |||||
res.sampler = dataloader.batch_sampler.sampler | |||||
if hasattr(dataloader.batch_sampler.sampler, "shuffle"): | |||||
res.shuffle = dataloader.batch_sampler.sampler.shuffle | |||||
elif isinstance(dataloader.batch_sampler.sampler, OneflowRandomSampler): | |||||
res.shuffle = True | |||||
else: | |||||
res.shuffle = False | |||||
# ReproduceBatchSampler 的情况 | |||||
elif hasattr(dataloader.batch_sampler, "batch_sampler"): | |||||
batch_sampler = dataloader.batch_sampler.batch_sampler | |||||
res.sampler = batch_sampler.sampler | |||||
if hasattr(batch_sampler.sampler, "shuffle"): | |||||
res.shuffle = dataloader.batch_sampler.sampler.shuffle | |||||
elif isinstance(batch_sampler.sampler, OneflowRandomSampler): | |||||
res.shuffle = True | |||||
else: | |||||
res.shuffle = False | |||||
else: | |||||
# 如果 dataloader.batch_sampler 没有 sampler 这个属性,那么说明其使用的是自己的 batch_sampler,且没有 "sampler" 属性; | |||||
# 这种情况下 DataLoader 会自己初始化一个 sampler;我们因此将这个默认初始化的 sampler 挂载到 res 上; | |||||
res.sampler = dataloader.sampler | |||||
res.shuffle = False | |||||
if hasattr(dataloader.batch_sampler, "drop_last"): | |||||
res.drop_last = getattr(dataloader.batch_sampler, "drop_last") | |||||
# 用户使用的是自己的 batch_sampler 并且其没有 "drop_last" 属性; | |||||
else: | |||||
res.drop_last = False | |||||
return res |
@@ -0,0 +1,114 @@ | |||||
import os | |||||
from typing import Dict, Union | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
from oneflow.utils.data import SequentialSampler as OneflowSequentialSampler | |||||
from oneflow.utils.data import BatchSampler as OneflowBatchSampler | |||||
__all__ = [ | |||||
"OneflowSingleDriver" | |||||
] | |||||
from .oneflow_driver import OneflowDriver | |||||
from fastNLP.core.drivers.oneflow_driver.utils import replace_sampler, replace_batch_sampler | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \ | |||||
ReproduceBatchSampler | |||||
from fastNLP.core.samplers import RandomSampler | |||||
from fastNLP.core.log import logger | |||||
class OneflowSingleDriver(OneflowDriver): | |||||
r""" | |||||
用于执行 ``oneflow`` 动态图 cpu 和 单卡 gpu 运算的 ``driver``; | |||||
:param model: 传入给 ``Trainer`` 的 ``model`` 参数; | |||||
:param device: oneflow.device,当前进程所使用的设备; | |||||
:param fp16: 是否开启 fp16;目前动态图的单卡下该参数无效; | |||||
:param oneflow_kwargs: | |||||
""" | |||||
def __init__(self, model, device: "oneflow.device", fp16: bool = False, oneflow_kwargs: Dict = None, **kwargs): | |||||
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) | |||||
if cuda_visible_devices == "": | |||||
device = oneflow.device("cpu") | |||||
logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in system environment variable, and we are gonna to" | |||||
"use `cpu` instead of `gpu` device.") | |||||
super(OneflowSingleDriver, self).__init__(model, fp16=fp16, oneflow_kwargs=oneflow_kwargs, **kwargs) | |||||
if device is None: | |||||
logger.debug("device is not set, fastNLP will try to automatically get it.") | |||||
try: | |||||
device = next(model.parameters()).device | |||||
assert isinstance(device, oneflow.device) | |||||
except: | |||||
raise ValueError("fastNLP cannot get device automatically, please set device explicitly.") | |||||
self.model_device = device | |||||
self.local_rank = 0 | |||||
self.global_rank = 0 | |||||
self.world_size = 1 | |||||
def setup(self): | |||||
r""" | |||||
将模型迁移到相应的设备上; | |||||
""" | |||||
if self.model_device is not None: | |||||
self.model.to(self.model_device) | |||||
def set_dist_repro_dataloader(self, dataloader, | |||||
dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None, | |||||
reproducible: bool = False): | |||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load_checkpoint 函数调用; | |||||
if isinstance(dist, ReproducibleBatchSampler): | |||||
return replace_batch_sampler(dataloader, dist) | |||||
elif isinstance(dist, ReproducibleSampler): | |||||
return replace_sampler(dataloader, dist) | |||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | |||||
args = self.get_dataloader_args(dataloader) | |||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
elif isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
return replace_sampler(dataloader, sampler) | |||||
if reproducible: | |||||
if type(args.batch_sampler) is OneflowBatchSampler: | |||||
if type(args.sampler) is OneflowSequentialSampler: | |||||
# 需要替换为不要 shuffle 的。 | |||||
sampler = RandomSampler(args.sampler.data_source, shuffle=False) | |||||
logger.debug("Replace oneflow SequentialSampler into fastNLP RandomSampler.") | |||||
return replace_sampler(dataloader, sampler) | |||||
batch_sampler = ReproduceBatchSampler( | |||||
batch_sampler=args.batch_sampler, | |||||
batch_size=args.batch_size, | |||||
drop_last=args.drop_last | |||||
) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
else: | |||||
return dataloader | |||||
def unwrap_model(self): | |||||
r""" | |||||
:return: 返回模型 | |||||
""" | |||||
return self.model | |||||
@property | |||||
def data_device(self): | |||||
r""" | |||||
:return: 数据和模型所在的设备; | |||||
""" | |||||
return self.model_device | |||||
def is_distributed(self): | |||||
r""" | |||||
:return: 返回当前使用的 driver 是否是分布式的 driver,在 ``OneflowSingleDriver`` 中返回 ``False``; | |||||
""" | |||||
return False |
@@ -0,0 +1,292 @@ | |||||
import os | |||||
from typing import Any, Dict, Optional | |||||
from enum import IntEnum | |||||
import contextlib | |||||
import random | |||||
import numpy as np | |||||
import inspect | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
from fastNLP.envs.utils import get_global_seed | |||||
from fastNLP.envs import ( | |||||
get_global_rank, | |||||
FASTNLP_BACKEND_LAUNCH, | |||||
FASTNLP_GLOBAL_SEED, | |||||
) | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler | |||||
from fastNLP.core.utils import auto_param_call | |||||
from fastNLP.core.log import logger | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
from oneflow.nn import Module | |||||
from oneflow.utils.data import DataLoader | |||||
from oneflow.utils.data import RandomSampler as oneflowRandomSampler | |||||
from oneflow.utils.data import SequentialSampler as oneflowSequentialSampler | |||||
from oneflow.utils.data import BatchSampler as oneflowBatchSampler | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as Module | |||||
__all__ = [ | |||||
'oneflow_seed_everything', | |||||
'optimizer_state_to_device' | |||||
] | |||||
def oneflow_seed_everything(seed: int = None, add_global_rank_to_seed: bool = True) -> int: | |||||
r""" | |||||
为 **oneflow**、**numpy**、**python.random** 伪随机数生成器设置种子。 | |||||
:param seed: 全局随机状态的整数值种子。如果为 ``None`` 则会根据时间戳生成一个种子。 | |||||
:param add_global_rank_to_seed: 在分布式训练中,是否在不同 **rank** 中使用不同的随机数。 | |||||
当设置为 ``True`` 时,**FastNLP** 会将种子加上当前的 ``global_rank``。 | |||||
""" | |||||
max_seed_value = np.iinfo(np.uint32).max | |||||
min_seed_value = np.iinfo(np.uint32).min | |||||
if seed is None: | |||||
if os.getenv(FASTNLP_BACKEND_LAUNCH) == "1": | |||||
seed = 42 | |||||
else: | |||||
seed = get_global_seed() | |||||
logger.info(f"'FASTNLP_GLOBAL_SEED' is set to {seed} automatically.") | |||||
if not isinstance(seed, int): | |||||
seed = int(seed) | |||||
if not (min_seed_value <= seed <= max_seed_value): | |||||
logger.rank_zero_warning("Your seed value is too big or too small for numpy, we will choose a random seed for you.") | |||||
seed %= max_seed_value | |||||
os.environ[FASTNLP_GLOBAL_SEED] = f"{seed}" | |||||
if add_global_rank_to_seed: | |||||
seed += get_global_rank() | |||||
random.seed(seed) | |||||
np.random.seed(seed) | |||||
oneflow.manual_seed(seed) | |||||
oneflow.cuda.manual_seed_all(seed) | |||||
return seed | |||||
class ForwardState(IntEnum): | |||||
TRAIN = 0 | |||||
VALIDATE = 1 | |||||
TEST = 2 | |||||
PREDICT = 3 | |||||
class _DDPWrappingModel(Module): | |||||
""" | |||||
该函数用于 DDP 训练时处理用户自己定制的 train_step 等函数; | |||||
之所以要使用这一额外的包裹模型,是因为在使用 DDP 时,必须使用 DistributedDataParallel 的 forward 函数才能实现正常的运行; | |||||
另一方面,我们要求用户在使用我们的框架时,需要针对不用的模式实现不同的处理函数,例如 'train_step', 'evaluate_step' 等; | |||||
然而,当使用 DistributedDataParallel 包裹 model 后,模型看不见其除了 forward 之外的方法;并且当我们尝试在训练过程中主动提取 | |||||
`model = model.module`,这同样会导致错误,会使得每一个gpu上的模型参数不同; | |||||
因此出于以上考虑,我们实现了这一函数; | |||||
对于更详细的解释,可以参考 'pytorch_lightning' 的 ddp 的设计; | |||||
""" | |||||
def __init__(self, model: Module): | |||||
super(_DDPWrappingModel, self).__init__() | |||||
self.model = model | |||||
def forward(self, batch, **kwargs) -> Dict: | |||||
""" | |||||
pytorch lightning 实现了先 unwrapping_model 的操作,但是感觉对于我们来说没有什么必须要,先写个注释放这里,之后有需求了再看; | |||||
""" | |||||
fn = kwargs.pop("fastnlp_fn") | |||||
signature_fn = kwargs.pop("fastnlp_signature_fn") | |||||
wo_auto_param_call = kwargs.pop("wo_auto_param_call") | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(fn, batch, signature_fn=signature_fn) | |||||
else: | |||||
return fn(batch) | |||||
class DummyGradScaler: | |||||
def __init__(self, *args, **kwargs): | |||||
pass | |||||
def get_scale(self): | |||||
return 1.0 | |||||
def is_enabled(self): | |||||
return False | |||||
def scale(self, outputs): | |||||
return outputs | |||||
def step(self, optimizer, *args, **kwargs): | |||||
optimizer.step(*args, **kwargs) | |||||
def update(self, new_scale=None): | |||||
pass | |||||
def unscale_(self, optimizer): | |||||
pass | |||||
def load_state_dict(self, state_dict): | |||||
pass | |||||
def state_dict(self): | |||||
return {} | |||||
def _build_fp16_env(dummy=False): | |||||
return | |||||
if dummy: | |||||
autocast = contextlib.ExitStack | |||||
GradScaler = DummyGradScaler | |||||
else: | |||||
if not oneflow.cuda.is_available(): | |||||
raise RuntimeError("Oneflow is not installed in gpu version, please use device='cpu'.") | |||||
if oneflow.cuda.get_device_capability(0)[0] < 7: | |||||
logger.rank_zero_warning( | |||||
"NOTE: your device does NOT support faster training with fp16, " | |||||
"please switch to FP32 which is likely to be faster" | |||||
) | |||||
try: | |||||
from oneflow.amp import GradScaler | |||||
from oneflow.cuda.amp import autocast, GradScaler | |||||
except ImportError: | |||||
raise RuntimeError("torch version too low (less than 1.6)") | |||||
return autocast, GradScaler | |||||
def replace_sampler(dataloader: "DataLoader", sampler): | |||||
r""" | |||||
替换 sampler (初始化一个新的 dataloader 的逻辑在于): | |||||
用户可能继承了 dataloader,定制了自己的 dataloader 类,这也是我们为什么先 `inspect.signature(dataloader)` 而不是直接 | |||||
`inspect.signature(DataLoader)` 的原因,因此同时注意到我们在外层重新初始化一个 dataloader 时也是使用的用户传进来的 dataloader | |||||
的类,而不是直接的 DataLoader; | |||||
如果需要定制自己的 dataloader,保证以下两点: | |||||
1. 在 __init__ 方法中加入 **kwargs,这是为了方便我们将 sampler 插入到具体的 DataLoader 的构造中; | |||||
2. 在 __init__ 方法中出现的参数,请务必挂为同样名字的实例属性,例如 self.one_arg_name = one_arg_name,这是因为我们只能通过属性 | |||||
来获取实际的参数的值; | |||||
""" | |||||
# 拿到实例属性; | |||||
instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')} | |||||
# 'multiprocessing_context' 是 user-defined function; | |||||
if getattr(dataloader, 'multiprocessing_context', None) is not None: | |||||
instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context | |||||
# 拿到 dataloader '__init__' 函数的默认函数签名; | |||||
init_params = dict(inspect.signature(dataloader.__init__).parameters) | |||||
# 防止用户的 DataLoader 是继承了 oneflow 的 DataLoader,然后还是使用了 **kwargs 的方式对父类传参数 | |||||
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | |||||
if has_variadic_kwargs and isinstance(dataloader, DataLoader): | |||||
# 防止用户写入了 super().__init__(**kwargs) | |||||
for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items(): | |||||
if key not in init_params and key != 'self': | |||||
init_params[key] = value | |||||
# 如果初始化dataloader所使用的参数不是默认值,那么我们需要将其记录下来用于重新初始化时设置; | |||||
non_default_params = {name for name, p in init_params.items() if | |||||
name in instance_attrs and p.default != instance_attrs[name]} | |||||
# add `dataset` as it might have been replaced with `*args` | |||||
non_default_params.add("dataset") | |||||
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} | |||||
if isinstance(dataloader, DataLoader): | |||||
reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None}) | |||||
batch_sampler = getattr(dataloader, "batch_sampler") | |||||
if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler): | |||||
raise RuntimeError("It should not be running here, please report a bug to us.") | |||||
required_args = { | |||||
p.name | |||||
for p in init_params.values() | |||||
if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) | |||||
and p.default is p.empty | |||||
and p.name not in reconstruct_args | |||||
} | |||||
# 在 attribute 中没有找到这些参数,导致了没有办法重新初始化 | |||||
if required_args: | |||||
required_args = sorted(required_args) | |||||
dataloader_self_name = dataloader.__class__.__name__ | |||||
raise Exception( | |||||
f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. " | |||||
f"But they are not found in the attribute of `{dataloader_self_name}`, fastNLP cannot determine its " | |||||
f"value when try to reinitialize `{dataloader_self_name}`, please add `{required_args}` to be " | |||||
f"`{dataloader_self_name}`'s attribute." | |||||
) | |||||
# 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs; | |||||
if not has_variadic_kwargs: | |||||
# the dataloader signature does not allow keyword arguments that need to be passed | |||||
missing_kwargs = reconstruct_args.keys() - init_params.keys() | |||||
if missing_kwargs: | |||||
missing_kwargs = sorted(missing_kwargs) | |||||
dataloader_self_name = dataloader.__class__.__name__ | |||||
raise Exception( | |||||
f"The parameter:{missing_kwargs} needed to reinitialize `{dataloader_self_name}` is not found." | |||||
) | |||||
# 如果没有kwargs,则保证一下只传入需要的参数 | |||||
if not isinstance(dataloader, DataLoader): | |||||
reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params} | |||||
return type(dataloader)(**reconstruct_args) | |||||
def replace_batch_sampler(dataloader, new_batch_sampler): | |||||
r""" | |||||
替换一个 dataloader 的 batch_sampler; | |||||
""" | |||||
params_keys = [k for k in dataloader.__dict__.keys() if not k.startswith("_")] | |||||
for k in ["batch_size", "sampler", "drop_last", "batch_sampler", "dataset_kind"]: | |||||
if k in params_keys: | |||||
params_keys.remove(k) | |||||
params = {k: getattr(dataloader, k) for k in params_keys} | |||||
params["batch_sampler"] = new_batch_sampler | |||||
if not isinstance(dataloader, DataLoader): | |||||
init_params = dict(inspect.signature(dataloader.__init__).parameters) | |||||
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | |||||
if not has_variadic_kwargs: | |||||
params = {key:value for key,value in params.items() if key in init_params} | |||||
return type(dataloader)(**params) | |||||
def optimizer_state_to_device(state, device): | |||||
r""" | |||||
将一个 ``optimizer`` 的 ``state_dict`` 迁移到对应的设备; | |||||
:param state: ``optimzier.state_dict()``; | |||||
:param device: 要迁移到的目的设备; | |||||
:return: 返回迁移后的新的 state_dict; | |||||
""" | |||||
new_state = {} | |||||
for name, param in state.items(): | |||||
if isinstance(param, dict): | |||||
new_state[name] = optimizer_state_to_device(param, device) | |||||
elif isinstance(param, oneflow.Tensor): | |||||
new_state[name] = param.to(device).clone() | |||||
else: | |||||
new_state[name] = param | |||||
return new_state | |||||
def _check_dataloader_args_for_distributed(args, controller='Trainer'): | |||||
if type(args.batch_sampler) is not oneflowBatchSampler or (type(args.sampler) not in {oneflowRandomSampler, | |||||
oneflowSequentialSampler}): | |||||
mode = 'training' if controller == 'Trainer' else 'evaluation' | |||||
substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler' | |||||
raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause " | |||||
f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into " | |||||
f"``{substitution}``. The customized sampler should set for distributed running " | |||||
f"before initializing ``{controller}`` , and then set the " | |||||
f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.") |
@@ -73,6 +73,7 @@ from .utils import ( | |||||
_FleetWrappingModel, | _FleetWrappingModel, | ||||
replace_sampler, | replace_sampler, | ||||
replace_batch_sampler, | replace_batch_sampler, | ||||
_check_dataloader_args_for_distributed | |||||
) | ) | ||||
from .dist_utils import fastnlp_paddle_all_gather, fastnlp_paddle_broadcast_object | from .dist_utils import fastnlp_paddle_all_gather, fastnlp_paddle_broadcast_object | ||||
@@ -129,15 +130,15 @@ class PaddleFleetDriver(PaddleDriver): | |||||
:param is_pull_by_paddle_run: 标记当前进程是否为通过 ``python -m paddle.distributed.launch`` 启动的。 | :param is_pull_by_paddle_run: 标记当前进程是否为通过 ``python -m paddle.distributed.launch`` 启动的。 | ||||
这个参数仅在 :class:`~fastNLP.core.Trainer` 中初始化 driver 时使用 | 这个参数仅在 :class:`~fastNLP.core.Trainer` 中初始化 driver 时使用 | ||||
:param fp16: 是否开启混合精度训练; | :param fp16: 是否开启混合精度训练; | ||||
:param paddle_kwargs: | |||||
* *fleet_kwargs* -- 用于在使用 ``PaddleFleetDriver`` 时指定 ``DataParallel`` 和 ``fleet`` 初始化时的参数,包括: | |||||
* *is_collective* -- 是否使用 paddle 集群式的分布式训练方法,目前仅支持为 ``True`` 的情况; | |||||
* *role_maker* -- 初始化 ``fleet`` 分布式训练 API 时使用的 ``RoleMaker``; | |||||
* 其它用于初始化 ``DataParallel`` 的参数; | |||||
* *gradscaler_kwargs* -- 用于 ``fp16=True`` 时,提供给 :class:`paddle.amp.GradScaler` 的参数; | |||||
:kwargs: | :kwargs: | ||||
* *paddle_kwargs* -- 用于在指定 ``driver`` 为 'paddle' 时设定具体 driver 实例的一些参数: | |||||
* fleet_kwargs -- 用于在使用 ``PaddleFleetDriver`` 时指定 ``DataParallel`` 和 ``fleet`` 初始化时的参数,包括: | |||||
* is_collective -- 是否使用 paddle 集群式的分布式训练方法,目前仅支持为 ``True`` 的情况; | |||||
* role_maker -- 初始化 ``fleet`` 分布式训练 API 时使用的 ``RoleMaker`` | |||||
* 其它用于初始化 ``DataParallel`` 的参数; | |||||
* wo_auto_param_call (``bool``) -- 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为; | * wo_auto_param_call (``bool``) -- 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为; | ||||
.. note:: | .. note:: | ||||
@@ -151,11 +152,12 @@ class PaddleFleetDriver(PaddleDriver): | |||||
parallel_device: Optional[Union[List[str], str]], | parallel_device: Optional[Union[List[str], str]], | ||||
is_pull_by_paddle_run: bool = False, | is_pull_by_paddle_run: bool = False, | ||||
fp16: bool = False, | fp16: bool = False, | ||||
paddle_kwargs: Dict = None, | |||||
**kwargs | **kwargs | ||||
): | ): | ||||
if USER_CUDA_VISIBLE_DEVICES not in os.environ: | if USER_CUDA_VISIBLE_DEVICES not in os.environ: | ||||
raise RuntimeError("To run paddle distributed training, please set `FASTNLP_BACKEND` to 'paddle' before using FastNLP.") | |||||
super(PaddleFleetDriver, self).__init__(model, fp16=fp16, **kwargs) | |||||
raise RuntimeError("To run paddle distributed training, please set `FASTNLP_BACKEND` to 'paddle' before using fastNLP.") | |||||
super(PaddleFleetDriver, self).__init__(model, fp16=fp16, paddle_kwargs=paddle_kwargs, **kwargs) | |||||
# 如果不是通过 launch 启动,要求用户必须传入 parallel_device | # 如果不是通过 launch 启动,要求用户必须传入 parallel_device | ||||
if not is_pull_by_paddle_run: | if not is_pull_by_paddle_run: | ||||
@@ -193,17 +195,14 @@ class PaddleFleetDriver(PaddleDriver): | |||||
self.world_size = None | self.world_size = None | ||||
self.global_rank = 0 | self.global_rank = 0 | ||||
self.gloo_rendezvous_dir = None | self.gloo_rendezvous_dir = None | ||||
# 分布式环境的其它参数设置 | |||||
paddle_kwargs = kwargs.get("paddle_kwargs", {}) | |||||
self._fleet_kwargs = paddle_kwargs.get("fleet_kwargs", {}) | |||||
self._fleet_kwargs = self._paddle_kwargs.get("fleet_kwargs", {}) | |||||
check_user_specific_params(self._fleet_kwargs, DataParallel.__init__, DataParallel.__name__) | check_user_specific_params(self._fleet_kwargs, DataParallel.__init__, DataParallel.__name__) | ||||
# fleet.init 中对于分布式策略的设置,详情可以参考 PaddlePaddle 的官方文档 | # fleet.init 中对于分布式策略的设置,详情可以参考 PaddlePaddle 的官方文档 | ||||
self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy()) | self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy()) | ||||
self.is_collective = self._fleet_kwargs.pop("is_collective", True) | self.is_collective = self._fleet_kwargs.pop("is_collective", True) | ||||
if not self.is_collective: | if not self.is_collective: | ||||
raise NotImplementedError("FastNLP only support `collective` for distributed training now.") | |||||
raise NotImplementedError("fastNLP only support `collective` for distributed training now.") | |||||
self.role_maker = self._fleet_kwargs.pop("role_maker", None) | self.role_maker = self._fleet_kwargs.pop("role_maker", None) | ||||
self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error") | self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error") | ||||
@@ -422,8 +421,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
# trainer, evaluator | # trainer, evaluator | ||||
if dist is None: | if dist is None: | ||||
if reproducible: | if reproducible: | ||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our " | |||||
"control.") | |||||
raise RuntimeError("It is not allowed to save checkpoint if the sampler is not allowed to be replaced.") | |||||
else: | else: | ||||
args = self.get_dataloader_args(dataloader) | args = self.get_dataloader_args(dataloader) | ||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | if isinstance(args.batch_sampler, ReproducibleBatchSampler): | ||||
@@ -454,6 +452,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
) | ) | ||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
else: | else: | ||||
_check_dataloader_args_for_distributed(args, controller='Trainer') | |||||
sampler = RandomSampler( | sampler = RandomSampler( | ||||
dataset=args.dataset, | dataset=args.dataset, | ||||
shuffle=args.shuffle, | shuffle=args.shuffle, | ||||
@@ -38,7 +38,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) | user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) | ||||
if is_in_paddle_launch_dist(): | if is_in_paddle_launch_dist(): | ||||
if user_visible_devices is None: | if user_visible_devices is None: | ||||
raise RuntimeError("To run paddle distributed training, please set `FASTNLP_BACKEND` to 'paddle' before using FastNLP.") | |||||
raise RuntimeError("To run paddle distributed training, please set `FASTNLP_BACKEND` to 'paddle' before using fastNLP.") | |||||
if device is not None: | if device is not None: | ||||
logger.rank_zero_warning("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull " | logger.rank_zero_warning("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull " | ||||
"up your script. And we will directly get the local device via environment variables.", once=True) | "up your script. And we will directly get the local device via environment variables.", once=True) | ||||
@@ -19,6 +19,7 @@ from fastNLP.envs import ( | |||||
rank_zero_call, | rank_zero_call, | ||||
) | ) | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.dataloaders import OverfitDataLoader | |||||
from fastNLP.core.samplers import ( | from fastNLP.core.samplers import ( | ||||
ReproducibleBatchSampler, | ReproducibleBatchSampler, | ||||
ReproducibleSampler, | ReproducibleSampler, | ||||
@@ -55,27 +56,32 @@ class PaddleDriver(Driver): | |||||
1. :class:`~fastNLP.core.drivers.PaddleSingleDriver`:实现了使用单卡和 ``cpu`` 训练的具体功能; | 1. :class:`~fastNLP.core.drivers.PaddleSingleDriver`:实现了使用单卡和 ``cpu`` 训练的具体功能; | ||||
2. :class:`~fastNLP.core.drivers.PaddleFleetDriver`:实现了使用 ``fleet`` 分布式训练 API 进行集群式分布式训练的具体功能; | 2. :class:`~fastNLP.core.drivers.PaddleFleetDriver`:实现了使用 ``fleet`` 分布式训练 API 进行集群式分布式训练的具体功能; | ||||
:param model: 训练时使用的 **PaddlePaddle** 模型; | |||||
:param fp16: 是否开启混合精度训练; | |||||
:kwargs: | |||||
* wo_auto_param_call (``bool``) -- 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为; | |||||
.. warning:: | |||||
.. note:: | |||||
您不应当直接初始化该类,然后传入给 ``Trainer``,换句话说,您应当使用该类的子类 ``PaddleSingleDriver`` 和 ``PaddleDDPDriver``,而不是 | |||||
该类本身; | |||||
关于该参数的详细说明,请参见 :class:`~fastNLP.core.controllers.Trainer` 中的描述;函数 ``auto_param_call`` 详见 :func:`fastNLP.core.utils.auto_param_call`。 | |||||
.. note:: | |||||
您可以在使用 ``PaddleSingleDriver`` 和 ``PaddleFleetDriver`` 时使用 ``PaddleDriver`` 提供的接口; | |||||
:param model: 训练时使用的 **PaddlePaddle** 模型; | |||||
:param fp16: 是否开启混合精度训练; | |||||
:param paddle_kwargs: | |||||
""" | """ | ||||
def __init__(self, model: "paddle.nn.Layer", fp16: Optional[bool] = False, **kwargs): | |||||
def __init__(self, model: "paddle.nn.Layer", fp16: Optional[bool] = False, paddle_kwargs: Dict = None, **kwargs): | |||||
if not isinstance(model, paddle.nn.Layer): | if not isinstance(model, paddle.nn.Layer): | ||||
raise ValueError(f"Parameter `model` can not be `{type(model)}` in `PaddleDriver`, it should be exactly " | raise ValueError(f"Parameter `model` can not be `{type(model)}` in `PaddleDriver`, it should be exactly " | ||||
f"`paddle.nn.Layer` type.") | f"`paddle.nn.Layer` type.") | ||||
super(PaddleDriver, self).__init__(model) | super(PaddleDriver, self).__init__(model) | ||||
self.fp16 = fp16 | self.fp16 = fp16 | ||||
self._paddle_kwargs = paddle_kwargs if paddle_kwargs is not None else {} | |||||
# scaler的参数 | # scaler的参数 | ||||
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | ||||
self.grad_scaler = _grad_scaler() | |||||
self.grad_scaler = _grad_scaler(**self._paddle_kwargs.get("gradscaler_kwargs", {})) | |||||
# 用来设置是否关闭 auto_param_call 中的参数匹配问题; | # 用来设置是否关闭 auto_param_call 中的参数匹配问题; | ||||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | ||||
@@ -93,7 +99,7 @@ class PaddleDriver(Driver): | |||||
self.grad_scaler.update() | self.grad_scaler.update() | ||||
def check_dataloader_legality(self, dataloader): | def check_dataloader_legality(self, dataloader): | ||||
if not isinstance(dataloader, DataLoader): | |||||
if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader): | |||||
raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") | raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") | ||||
if dataloader.batch_size is None and dataloader.batch_sampler is None: | if dataloader.batch_size is None and dataloader.batch_sampler is None: | ||||
raise ValueError("Please ensure at least one of your dataloader's batch_size and batch_sampler" | raise ValueError("Please ensure at least one of your dataloader's batch_size and batch_sampler" | ||||
@@ -154,7 +160,7 @@ class PaddleDriver(Driver): | |||||
:param only_state_dict: 是否只保存模型的 ``state_dict``;如果为 ``False``,则会调用 ``paddle.jit.save`` | :param only_state_dict: 是否只保存模型的 ``state_dict``;如果为 ``False``,则会调用 ``paddle.jit.save`` | ||||
函数保存整个模型的参数,此时需要传入 ``input_spec`` 参数; | 函数保存整个模型的参数,此时需要传入 ``input_spec`` 参数; | ||||
:kwargs: | :kwargs: | ||||
* input_spec -- 描述存储模型 ``forward`` 方法的输入; | |||||
* *input_spec* -- 描述存储模型 ``forward`` 方法的输入; | |||||
当 ``only_state_dict`` 为 ``False`` 时必须传入,否则加载时会报错。您可以通过 ``InputSpec`` 或者示例 ``Tensor`` | 当 ``only_state_dict`` 为 ``False`` 时必须传入,否则加载时会报错。您可以通过 ``InputSpec`` 或者示例 ``Tensor`` | ||||
进行描述。详细的使用方法可以参考 **PaddlePaddle** `关于 paddle.jit.save 函数的文档 <https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/jit/save_cn.html#save>`_; | 进行描述。详细的使用方法可以参考 **PaddlePaddle** `关于 paddle.jit.save 函数的文档 <https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/jit/save_cn.html#save>`_; | ||||
""" | """ | ||||
@@ -222,26 +228,12 @@ class PaddleDriver(Driver): | |||||
num_consumed_batches = states.pop("num_consumed_batches") | num_consumed_batches = states.pop("num_consumed_batches") | ||||
if hasattr(sampler, "state_dict") and callable(sampler.state_dict): | if hasattr(sampler, "state_dict") and callable(sampler.state_dict): | ||||
sampler_states = sampler.state_dict() | sampler_states = sampler.state_dict() | ||||
# 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||||
# 会造成多余实际消耗的问题。 | |||||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||||
if num_consumed_samples_array is not None: | |||||
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||||
if dataloader_args.batch_size is not None: | |||||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||||
else: # 有可能 batch_size 为 None,就只有损失精度了 | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
"it may cause missing some samples when reload.") | |||||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||||
if dataloader_args.batch_size is not None: | |||||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
* num_consumed_batches | |||||
else: | else: | ||||
if dataloader_args.batch_size is not None: | |||||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
* num_consumed_batches | |||||
else: | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
"it may cause missing some samples when reload.") | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
"it may cause missing some samples when reload.") | |||||
else: | else: | ||||
raise RuntimeError( | raise RuntimeError( | ||||
"The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") | "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") | ||||
@@ -26,6 +26,11 @@ if _NEED_IMPORT_PADDLE: | |||||
import paddle | import paddle | ||||
from paddle import DataParallel | from paddle import DataParallel | ||||
from paddle.fluid.reader import _DatasetKind | from paddle.fluid.reader import _DatasetKind | ||||
from paddle.io import ( | |||||
RandomSampler as PaddleRandomSampler, | |||||
SequenceSampler as PaddleSequenialSampler, | |||||
BatchSampler as PaddleBatchSampler, | |||||
) | |||||
__all__ = [ | __all__ = [ | ||||
"PaddleSingleDriver", | "PaddleSingleDriver", | ||||
@@ -38,6 +43,8 @@ class PaddleSingleDriver(PaddleDriver): | |||||
:param model: 训练时使用的 **PaddlePaddle** 模型; | :param model: 训练时使用的 **PaddlePaddle** 模型; | ||||
:param device: 训练使用的设备; | :param device: 训练使用的设备; | ||||
:param fp16: 是否开启混合精度训练; | :param fp16: 是否开启混合精度训练; | ||||
:param paddle_kwargs: | |||||
* *gradscaler_kwargs* -- 用于 ``fp16=True`` 时,提供给 :class:`paddle.amp.GradScaler` 的参数; | |||||
:kwargs: | :kwargs: | ||||
* wo_auto_param_call (``bool``) -- 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为; | * wo_auto_param_call (``bool``) -- 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为; | ||||
@@ -46,7 +53,7 @@ class PaddleSingleDriver(PaddleDriver): | |||||
关于该参数的详细说明,请参见 :class:`~fastNLP.core.controllers.Trainer` 中的描述;函数 ``auto_param_call`` 详见 :func:`fastNLP.core.utils.auto_param_call`。 | 关于该参数的详细说明,请参见 :class:`~fastNLP.core.controllers.Trainer` 中的描述;函数 ``auto_param_call`` 详见 :func:`fastNLP.core.utils.auto_param_call`。 | ||||
""" | """ | ||||
def __init__(self, model: "paddle.nn.Layer", device: Union[str, int], fp16: Optional[bool] = False, **kwargs): | |||||
def __init__(self, model: "paddle.nn.Layer", device: Union[str, int], fp16: Optional[bool] = False, paddle_kwargs: Dict = None, **kwargs): | |||||
if isinstance(model, DataParallel): | if isinstance(model, DataParallel): | ||||
raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`") | raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`") | ||||
@@ -56,7 +63,7 @@ class PaddleSingleDriver(PaddleDriver): | |||||
logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in system environment variable, and we are gonna to" | logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in system environment variable, and we are gonna to" | ||||
"use `cpu` instead of `gpu` device.") | "use `cpu` instead of `gpu` device.") | ||||
super(PaddleSingleDriver, self).__init__(model, fp16=fp16, **kwargs) | |||||
super(PaddleSingleDriver, self).__init__(model, fp16=fp16, paddle_kwargs=paddle_kwargs, **kwargs) | |||||
if device is None: | if device is None: | ||||
raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.") | raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.") | ||||
@@ -122,19 +129,21 @@ class PaddleSingleDriver(PaddleDriver): | |||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
if reproducible: | if reproducible: | ||||
if isinstance(args.sampler, paddle.io.RandomSampler): | |||||
if getattr(args.sampler, '_num_samples', None) is None \ | |||||
and getattr(args.sampler, 'replacements', False) is False \ | |||||
and getattr(args.sampler, 'generator', None) is None: | |||||
# 如果本来就是随机的,并且没有定制,直接替换掉。 | |||||
sampler = RandomSampler(args.sampler.data_source, shuffle=True) | |||||
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | |||||
if type(args.batch_sampler) is PaddleBatchSampler: | |||||
if type(args.sampler) is PaddleRandomSampler: | |||||
if isinstance(args.sampler, PaddleRandomSampler): | |||||
if getattr(args.sampler, '_num_samples', None) is None \ | |||||
and getattr(args.sampler, 'replacements', False) is False \ | |||||
and getattr(args.sampler, 'generator', None) is None: | |||||
# 如果本来就是随机的,并且没有定制,直接替换掉。 | |||||
sampler = RandomSampler(args.sampler.data_source, shuffle=True) | |||||
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | |||||
return replace_sampler(dataloader, sampler) | |||||
elif type(args.sampler) is PaddleSequenialSampler: | |||||
# 需要替换为不要 shuffle 的。 | |||||
sampler = RandomSampler(args.sampler.data_source, shuffle=False) | |||||
logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.") | |||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
elif isinstance(args.sampler, paddle.io.SequenceSampler): | |||||
# 需要替换为不要 shuffle 的。 | |||||
sampler = RandomSampler(args.sampler.data_source, shuffle=False) | |||||
logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.") | |||||
return replace_sampler(dataloader, sampler) | |||||
batch_sampler = ReproduceBatchSampler( | batch_sampler = ReproduceBatchSampler( | ||||
batch_sampler=args.batch_sampler, | batch_sampler=args.batch_sampler, | ||||
batch_size=args.batch_size, | batch_size=args.batch_size, | ||||
@@ -15,6 +15,7 @@ from fastNLP.envs import ( | |||||
FASTNLP_BACKEND_LAUNCH, | FASTNLP_BACKEND_LAUNCH, | ||||
FASTNLP_GLOBAL_SEED, | FASTNLP_GLOBAL_SEED, | ||||
) | ) | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler | |||||
from fastNLP.core.utils import auto_param_call, paddle_to | from fastNLP.core.utils import auto_param_call, paddle_to | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -23,7 +24,7 @@ if _NEED_IMPORT_PADDLE: | |||||
import paddle | import paddle | ||||
from paddle import nn | from paddle import nn | ||||
from paddle.nn import Layer | from paddle.nn import Layer | ||||
from paddle.io import DataLoader, BatchSampler | |||||
from paddle.io import DataLoader, BatchSampler, RandomSampler, SequenceSampler | |||||
from paddle.amp import auto_cast, GradScaler | from paddle.amp import auto_cast, GradScaler | ||||
else: | else: | ||||
from fastNLP.core.utils.dummy_class import DummyClass as Layer | from fastNLP.core.utils.dummy_class import DummyClass as Layer | ||||
@@ -129,7 +130,7 @@ def _build_fp16_env(dummy=False): | |||||
"NOTE: your device does NOT support faster training with fp16, " | "NOTE: your device does NOT support faster training with fp16, " | ||||
"please switch to FP32 which is likely to be faster" | "please switch to FP32 which is likely to be faster" | ||||
) | ) | ||||
return auto_cast, GradScaler | |||||
return auto_cast, GradScaler | |||||
def find_free_ports(num): | def find_free_ports(num): | ||||
""" | """ | ||||
@@ -178,23 +179,22 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler | |||||
# 中寻找;VAR_KEYWORD 代表 **kwargs | # 中寻找;VAR_KEYWORD 代表 **kwargs | ||||
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | ||||
if has_variadic_kwargs: | if has_variadic_kwargs: | ||||
init_params.update(dict(inspect.signature(DataLoader.__init__).parameters)) | |||||
del init_params["self"] | |||||
for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items(): | |||||
if key not in init_params and key != 'self': | |||||
init_params[key] = value | |||||
# 因为我们刚才可能用 DataLoader 的默认参数将用户定制的 dataloader 的参数覆盖掉了,因此需要重新弄一遍; | |||||
# 将同时在实例名和参数名中出现且不是默认值的参数收集起来 | |||||
# 如果初始化dataloader所使用的参数不是默认值,那么我们需要将其记录下来用于重新初始化时设置; | |||||
non_default_params = {name for name, p in init_params.items() if | non_default_params = {name for name, p in init_params.items() if | ||||
name in instance_attrs and p.default != instance_attrs[name]} | name in instance_attrs and p.default != instance_attrs[name]} | ||||
# add `dataset` as it might have been replaced with `*args` | # add `dataset` as it might have been replaced with `*args` | ||||
non_default_params.add("dataset") | non_default_params.add("dataset") | ||||
# 收集不是默认值的参数和它的值 | |||||
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} | reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} | ||||
# persistent_workers 在类中的对应成员带有下划线,因此添加进来 | |||||
reconstruct_args.update({ | |||||
"batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1, | |||||
"persistent_workers": dataloader._persistent_workers, | |||||
}) | |||||
if isinstance(dataloader, DataLoader): | |||||
reconstruct_args.update({ | |||||
"batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1, | |||||
"persistent_workers": dataloader._persistent_workers, | |||||
}) | |||||
# POSITIONAL_OR_KEYWORD 代表一般的参数 | # POSITIONAL_OR_KEYWORD 代表一般的参数 | ||||
# 收集初始化函数中出现的、一般形式的、不带默认值且不在 reconstruct_args 中的参数 | # 收集初始化函数中出现的、一般形式的、不带默认值且不在 reconstruct_args 中的参数 | ||||
@@ -212,9 +212,10 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler | |||||
required_args = sorted(required_args) | required_args = sorted(required_args) | ||||
dataloader_self_name = dataloader.__class__.__name__ | dataloader_self_name = dataloader.__class__.__name__ | ||||
raise Exception( | raise Exception( | ||||
f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. " | |||||
"This would fail as some of the `__init__` arguments are not available as instance attributes. " | |||||
f"The missing attributes are {required_args}. " | |||||
f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. " | |||||
f"But they are not found in the attribute of `{dataloader_self_name}`, fastNLP cannot determine its " | |||||
f"value when try to reinitialize `{dataloader_self_name}`, please add `{required_args}` to be " | |||||
f"`{dataloader_self_name}`'s attribute." | |||||
) | ) | ||||
# 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs; | # 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs; | ||||
@@ -226,10 +227,11 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler | |||||
missing_kwargs = sorted(missing_kwargs) | missing_kwargs = sorted(missing_kwargs) | ||||
dataloader_self_name = dataloader.__class__.__name__ | dataloader_self_name = dataloader.__class__.__name__ | ||||
raise Exception( | raise Exception( | ||||
f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. " | |||||
"This would fail as it doesn't expose all its attributes in the `__init__` signature. " | |||||
f"The missing arguments are {missing_kwargs}. " | |||||
f"The parameter:{missing_kwargs} needed to reinitialize `{dataloader_self_name}` is not found." | |||||
) | ) | ||||
# 如果没有kwargs,则保证一下只传入需要的参数 | |||||
if not isinstance(dataloader, DataLoader): | |||||
reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params} | |||||
return type(dataloader)(**reconstruct_args) | return type(dataloader)(**reconstruct_args) | ||||
@@ -237,6 +239,9 @@ def replace_sampler(dataloader, new_sampler): | |||||
""" | """ | ||||
使用 ``new_sampler`` 重新构建一个 ``BatchSampler``,并替换到 ``dataloader`` 中 | 使用 ``new_sampler`` 重新构建一个 ``BatchSampler``,并替换到 ``dataloader`` 中 | ||||
""" | """ | ||||
batch_sampler = getattr(dataloader, "batch_sampler") | |||||
if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler): | |||||
raise RuntimeError("It should not be running here, please report a bug to us.") | |||||
new_batch_sampler = deepcopy(dataloader.batch_sampler) | new_batch_sampler = deepcopy(dataloader.batch_sampler) | ||||
new_batch_sampler.sampler = new_sampler | new_batch_sampler.sampler = new_sampler | ||||
return replace_batch_sampler(dataloader, new_batch_sampler) | return replace_batch_sampler(dataloader, new_batch_sampler) | ||||
@@ -251,3 +256,14 @@ def optimizer_state_to_device(state, device): | |||||
else: | else: | ||||
new_state[name] = param | new_state[name] = param | ||||
return new_state | return new_state | ||||
def _check_dataloader_args_for_distributed(args, controller='Trainer'): | |||||
if type(args.batch_sampler) is not BatchSampler or (type(args.sampler) not in {RandomSampler, | |||||
SequenceSampler}): | |||||
mode = 'training' if controller == 'Trainer' else 'evaluation' | |||||
substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler' | |||||
raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause " | |||||
f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into " | |||||
f"``{substitution}``. The customized sampler should set for distributed running " | |||||
f"before initializing ``{controller}`` , and then set the " | |||||
f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.") |
@@ -1,6 +1,7 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'TorchDDPDriver', | 'TorchDDPDriver', | ||||
'TorchSingleDriver', | 'TorchSingleDriver', | ||||
'DeepSpeedDriver', | |||||
'TorchDriver', | 'TorchDriver', | ||||
'torch_seed_everything', | 'torch_seed_everything', | ||||
'optimizer_state_to_device' | 'optimizer_state_to_device' | ||||
@@ -10,6 +11,7 @@ from .ddp import TorchDDPDriver | |||||
# todo 实现 fairscale 后再将 fairscale 导入到这里; | # todo 实现 fairscale 后再将 fairscale 导入到这里; | ||||
from .single_device import TorchSingleDriver | from .single_device import TorchSingleDriver | ||||
from .torch_driver import TorchDriver | from .torch_driver import TorchDriver | ||||
from .deepspeed import DeepSpeedDriver | |||||
from .utils import torch_seed_everything, optimizer_state_to_device | from .utils import torch_seed_everything, optimizer_state_to_device | ||||
@@ -159,6 +159,7 @@ from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, Unrepeated | |||||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC | from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object | from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object | ||||
from .utils import _check_dataloader_args_for_distributed | |||||
class TorchDDPDriver(TorchDriver): | class TorchDDPDriver(TorchDriver): | ||||
@@ -234,7 +235,12 @@ class TorchDDPDriver(TorchDriver): | |||||
:param parallel_device: 用于分布式训练的 ``gpu`` 设备; | :param parallel_device: 用于分布式训练的 ``gpu`` 设备; | ||||
:param is_pull_by_torch_run: 标志当前的脚本的启动是否由 ``python -m torch.distributed.launch`` 启动的; | :param is_pull_by_torch_run: 标志当前的脚本的启动是否由 ``python -m torch.distributed.launch`` 启动的; | ||||
:param fp16: 是否开启 fp16 训练; | :param fp16: 是否开启 fp16 训练; | ||||
:param kwargs: 其余的一些用于设定 ddp 训练的参数; | |||||
:param torch_kwargs: | |||||
* *ddp_kwargs* -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入 | |||||
{'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等; | |||||
* *set_grad_to_none* -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | |||||
* *non_blocking* -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; | |||||
* *gradscaler_kwargs* -- 用于 fp16=True 时,提供给 ``torch.amp.cuda.GradScaler`` 的参数; | |||||
""" | """ | ||||
def __init__( | def __init__( | ||||
@@ -243,11 +249,12 @@ class TorchDDPDriver(TorchDriver): | |||||
parallel_device: Optional[Union[List["torch.device"], "torch.device"]], | parallel_device: Optional[Union[List["torch.device"], "torch.device"]], | ||||
is_pull_by_torch_run: bool = False, | is_pull_by_torch_run: bool = False, | ||||
fp16: bool = False, | fp16: bool = False, | ||||
torch_kwargs: Dict = None, | |||||
**kwargs | **kwargs | ||||
): | ): | ||||
# 在加入很多东西后,需要注意这里调用 super 函数的位置; | # 在加入很多东西后,需要注意这里调用 super 函数的位置; | ||||
super(TorchDDPDriver, self).__init__(model, fp16=fp16, **kwargs) | |||||
super(TorchDDPDriver, self).__init__(model, fp16=fp16, torch_kwargs=torch_kwargs, **kwargs) | |||||
if isinstance(model, torch.nn.DataParallel): | if isinstance(model, torch.nn.DataParallel): | ||||
raise ValueError(f"Parameter `model` can not be `DataParallel` in `TorchDDPDriver`, it should be " | raise ValueError(f"Parameter `model` can not be `DataParallel` in `TorchDDPDriver`, it should be " | ||||
@@ -417,6 +424,7 @@ class TorchDDPDriver(TorchDriver): | |||||
os.environ['MASTER_ADDR'] = self.master_address | os.environ['MASTER_ADDR'] = self.master_address | ||||
os.environ['MASTER_PORT'] = self.master_port | os.environ['MASTER_PORT'] = self.master_port | ||||
os.environ["RANK"] = "0" | |||||
os.environ["LOCAL_RANK"] = str(self.local_rank) | os.environ["LOCAL_RANK"] = str(self.local_rank) | ||||
os.environ["WORLD_SIZE"] = f"{self.world_size}" | os.environ["WORLD_SIZE"] = f"{self.world_size}" | ||||
@@ -429,6 +437,7 @@ class TorchDDPDriver(TorchDriver): | |||||
for rank in range(1, len(self.parallel_device)): | for rank in range(1, len(self.parallel_device)): | ||||
env_copy = os.environ.copy() | env_copy = os.environ.copy() | ||||
env_copy["LOCAL_RANK"] = f"{rank}" | env_copy["LOCAL_RANK"] = f"{rank}" | ||||
env_copy["RANK"] = f"{rank}" | |||||
# 如果是多机,一定需要用户自己拉起,因此我们自己使用 open_subprocesses 开启的进程的 FASTNLP_GLOBAL_RANK 一定是 LOCAL_RANK; | # 如果是多机,一定需要用户自己拉起,因此我们自己使用 open_subprocesses 开启的进程的 FASTNLP_GLOBAL_RANK 一定是 LOCAL_RANK; | ||||
env_copy[FASTNLP_GLOBAL_RANK] = str(rank) | env_copy[FASTNLP_GLOBAL_RANK] = str(rank) | ||||
@@ -535,8 +544,7 @@ class TorchDDPDriver(TorchDriver): | |||||
# trainer, evaluator | # trainer, evaluator | ||||
if dist is None: | if dist is None: | ||||
if reproducible: | if reproducible: | ||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " | |||||
"control.") | |||||
raise RuntimeError("It is not allowed to save checkpoint if the sampler is not allowed to be replaced.") | |||||
else: | else: | ||||
args = self.get_dataloader_args(dataloader) | args = self.get_dataloader_args(dataloader) | ||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | if isinstance(args.batch_sampler, ReproducibleBatchSampler): | ||||
@@ -565,6 +573,7 @@ class TorchDDPDriver(TorchDriver): | |||||
) | ) | ||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
else: | else: | ||||
_check_dataloader_args_for_distributed(args, controller='Trainer') | |||||
sampler = RandomSampler( | sampler = RandomSampler( | ||||
dataset=args.dataset, | dataset=args.dataset, | ||||
shuffle=args.shuffle, | shuffle=args.shuffle, | ||||
@@ -582,6 +591,7 @@ class TorchDDPDriver(TorchDriver): | |||||
if isinstance(args.sampler, ReproducibleSampler): | if isinstance(args.sampler, ReproducibleSampler): | ||||
sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) | sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) | ||||
elif not isinstance(args.sampler, UnrepeatedSampler): | elif not isinstance(args.sampler, UnrepeatedSampler): | ||||
_check_dataloader_args_for_distributed(args, controller='Evaluator') | |||||
sampler = UnrepeatedSequentialSampler( | sampler = UnrepeatedSequentialSampler( | ||||
dataset=args.dataset | dataset=args.dataset | ||||
) | ) | ||||
@@ -0,0 +1,445 @@ | |||||
import os | |||||
import argparse | |||||
import logging | |||||
from pathlib import Path | |||||
from typing import Union, Dict, List | |||||
from .torch_driver import TorchDriver | |||||
from .ddp import TorchDDPDriver | |||||
from .utils import _create_default_config, _DeepSpeedWrappingModel | |||||
from fastNLP.core.utils import nullcontext | |||||
from fastNLP.core.log import logger | |||||
from fastNLP.envs import( | |||||
FASTNLP_DISTRIBUTED_CHECK, | |||||
FASTNLP_CHECKPOINT_FILENAME | |||||
) | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
import torch.distributed as dist | |||||
from torch.optim import Optimizer | |||||
if _NEED_IMPORT_DEEPSPEED: | |||||
import deepspeed | |||||
from deepspeed import DeepSpeedEngine, DeepSpeedOptimizer | |||||
__all__ = [ | |||||
"DeepSpeedDriver", | |||||
] | |||||
class DeepSpeedDriver(TorchDDPDriver): | |||||
""" | |||||
实现 ``deepspeed`` 分布式训练的 ``Driver``。 | |||||
.. note:: | |||||
您在绝大多数情况下不需要自己使用到该类,通过向 ``Trainer`` 传入正确的参数,您可以方便快速地部署您的分布式训练; | |||||
``DeepSpeedDriver`` 目前支持的三种启动方式: | |||||
1. 用户自己不进行任何操作,直接使用我们的 ``Trainer``,这时是由我们自己使用 ``open_subprocesses`` 拉起多个进程, | |||||
然后 ``DeepSpeedDriver`` 自己通过调用 ``deepspeed.initialize`` 来初始化模型和同心组;(情况 A) | |||||
.. code-block:: | |||||
trainer = Trainer( | |||||
... | |||||
driver='deepspeed', | |||||
device=[0, 1] | |||||
) | |||||
trainer.run() | |||||
通过运行 ``python train.py`` 启动; | |||||
2. 用户同样不在 ``Trainer`` 之外初始化 ``deepspeed``,但是用户自己使用 ``python -m torch.distributed.launch`` 拉起来创建多个进程,这时我们仍旧 | |||||
会通过调用 ``model.initialize`` 来初始化 ``ddp`` 的通信组;(情况 B) | |||||
.. code-block:: | |||||
trainer = Trainer( | |||||
... | |||||
driver='deepspeed', | |||||
device=None | |||||
) | |||||
trainer.run() | |||||
通过运行 ``deepspeed train.py`` 启动; | |||||
3. 用户自己在外面初始化 ``deepspeed``,并且通过 ``deepspeed train.py`` 拉起,这时无论是多个进程的拉起和通信组的建立 | |||||
都由用户自己操作,我们只会在 ``driver.setup`` 的时候对 ``DeepSpeedDriver`` 设置一些必要的属性值;(情况 C) | |||||
.. code-block:: | |||||
import deepspeed | |||||
# 初始化 | |||||
model, _, _, _ = deepspeed.initialize(model, ...) | |||||
trainer = Trainer( | |||||
... | |||||
driver='deepspeed', | |||||
device=None | |||||
) | |||||
trainer.run() | |||||
通过运行 ``deepspeed train.py`` 启动; | |||||
:param model: 传入给 ``Trainer`` 的 ``model`` 参数; | |||||
:param parallel_device: 用于分布式训练的 ``gpu`` 设备; | |||||
:param is_pull_by_torch_run: 标志当前的脚本的启动是否由 ``python -m torch.distributed.launch`` 启动的; | |||||
:param fp16: 是否开启 fp16 训练; | |||||
:param deepspeed_kwargs: | |||||
* *strategy* -- 使用 ZeRO 优化的策略,默认为 ``deepspeed``;目前仅支持以下值: | |||||
* ``deepspeed`` -- 使用 ZeRO 的第二阶段,等同于 ``deepspeed_stage_2``; | |||||
* ``deepspeed_stage_1`` -- 使用 ZeRO 的第一阶段,仅将 ``optimizer`` 的状态分散到不同设备上; | |||||
* ``deepspeed_stage_2`` -- 使用 ZeRO 的第二阶段,将 ``optimizer`` 和**梯度**分散到不同设备上; | |||||
* ``deepspeed_stage_2_offload`` -- 使用 ZeRO 的第二阶段,并且借助 cpu 的内存来进一步节约显存; | |||||
* ``deepspeed_stage_3`` -- 使用 ZeRO 的第三阶段,将 ``optimizer`` 、**梯度**和**模型**分散到不同设备上; | |||||
* ``deepspeed_stage_3_offload`` -- 使用 ZeRO 的第三阶段,并且借助 cpu 的内存来进一步节约显存; | |||||
* ``deepspeed_stage_3_offload_nvme`` -- 使用 ZeRO 的第三阶段,并且借助 NVMe 硬盘来进一步节约显存; | |||||
* *logging_level* -- ``deepspeed`` 库的日志等级,默认为 **logging.ERROR**; | |||||
* *config* -- ``deepspeed`` 的各项设置;**FastNLP** 允许用户传入自己的设置以增强灵活性,但这会使参数 | |||||
中的 ``optimizer`` 、``strategy`` 、 ``fp16`` 等失效,即当这个参数存在时,**FastNLP** 会用该参数覆盖 | |||||
其它的设置; | |||||
""" | |||||
# TODO fp16 load_config | |||||
def __init__( | |||||
self, | |||||
model, | |||||
parallel_device: Union[List["torch.device"], "torch.device"], | |||||
is_pull_by_torch_run = False, | |||||
fp16: bool = False, | |||||
deepspeed_kwargs: Dict = None, | |||||
**kwargs | |||||
): | |||||
assert _NEED_IMPORT_DEEPSPEED, "Deepspeed is not imported." | |||||
kwargs.pop("torch_kwargs", None) | |||||
self._ds_kwargs = deepspeed_kwargs | |||||
TorchDriver.__init__(self, model=model, fp16=False, torch_kwargs=deepspeed_kwargs, **kwargs) | |||||
self.fp16 = fp16 | |||||
# 如果用户自己在外面初始化 DDP,那么其一定是通过 python -m torch.distributed.launch 拉起的; | |||||
self.is_pull_by_torch_run = is_pull_by_torch_run | |||||
self.parallel_device = parallel_device | |||||
if not is_pull_by_torch_run and parallel_device is None: | |||||
raise ValueError( | |||||
"Parameter `parallel_device` can not be None when using `TorchDeepSpeedDriver`. This error is caused " | |||||
"when your value of parameter `device` is `None` in your `Trainer` instance.") | |||||
# 注意我们在 initialize_torch_driver 中的逻辑就是如果是 is_pull_by_torch_run,那么我们就直接把 parallel_device 置为当前进程的gpu; | |||||
if is_pull_by_torch_run: | |||||
self.model_device = parallel_device | |||||
else: | |||||
# 我们的 model_device 一定是 torch.device,而不是一个 list; | |||||
self.model_device = parallel_device[self.local_rank] | |||||
# 如果用户自己在外面初始化了 deepspeed; | |||||
self.outside_ddp = False | |||||
if dist.is_initialized() and FASTNLP_DISTRIBUTED_CHECK not in os.environ and \ | |||||
"fastnlp_torch_launch_not_ddp" not in os.environ: | |||||
# 如果用户自己在外面初始化了 deepspeed,那么我们要求用户传入的模型一定是已经由 DeepSpeedEngine 包裹后的模型; | |||||
if not isinstance(model, DeepSpeedEngine): | |||||
raise RuntimeError( | |||||
"It is not allowed to input a normal model instead of `DeepSpeedEngine` when" | |||||
"you initialize the ddp process out of our control.") | |||||
self.outside_ddp = True | |||||
self.config = model.config | |||||
self.model_device = None | |||||
self._data_device = kwargs.get("data_device", None) | |||||
if isinstance(self._data_device, int): | |||||
if self._data_device < 0: | |||||
raise ValueError("Parameter `data_device` can not be smaller than 0.") | |||||
_could_use_device_num = torch.cuda.device_count() | |||||
if self._data_device >= _could_use_device_num: | |||||
raise ValueError("The gpu device that parameter `device` specifies is not existed.") | |||||
self._data_device = torch.device(f"cuda:{self._data_device}") | |||||
elif isinstance(self._data_device, str): | |||||
self._data_device = torch.device(self._data_device) | |||||
elif self._data_device is not None and not isinstance(self._data_device, torch.device): | |||||
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | |||||
self._master_port = None | |||||
# world_size 表示的就是全局的显卡的数量; | |||||
self.world_size = None # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device) | |||||
self.global_rank = 0 | |||||
self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error") | |||||
assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type." | |||||
if self.output_from_new_proc not in {"all", "ignore", "only_error"}: | |||||
os.makedirs(name=self.output_from_new_proc, exist_ok=True) | |||||
self.output_from_new_proc = os.path.abspath(self.output_from_new_proc) | |||||
self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; | |||||
self._has_ddpwrapped = False # 判断传入的模型是否经过 _has_ddpwrapped 包裹; | |||||
self.accumulation_steps = kwargs.get("accumulation_steps", 1) | |||||
# 获取 batch_size 以设置 train_micro_batch_size_per_gpu 参数 | |||||
train_dl = kwargs.get("train_dataloader", None) | |||||
if train_dl is not None: | |||||
self.train_micro_batch_size = self.get_dataloader_args(train_dl).batch_size | |||||
else: | |||||
logger.warning("No `train_dataloader` found, and we will set `train_micro_batch_size_per_gpu`" | |||||
"to 1 for deepspeed configuration.") | |||||
self.train_micro_batch_size = 1 | |||||
self.strategy = self._ds_kwargs.get("strategy", "deepspeed") | |||||
deepspeed_logging_level = self._ds_kwargs.get("logging_level", logging.ERROR) | |||||
deepspeed.utils.logging.logger.setLevel(deepspeed_logging_level) | |||||
@staticmethod | |||||
def _check_optimizer_legality(optimizers): | |||||
for each_optimizer in optimizers: | |||||
if not isinstance(each_optimizer, (Optimizer, DeepSpeedOptimizer)): | |||||
raise TypeError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' or " | |||||
f"'DeepSpeedOptimizer'type, not {type(each_optimizer)}.") | |||||
def setup(self): | |||||
r""" | |||||
准备分布式环境,该函数主要做以下两件事情: | |||||
1. 开启多进程,每个 gpu 设备对应单独的一个进程; | |||||
2. 使用 ``deepspeed.initialize`` 包裹模型; | |||||
""" | |||||
if len(self.optimizers) != 1: | |||||
raise ValueError("Multi optimizers is not supported for `DeepSpeedDriver` right now.") | |||||
if self._has_setup: | |||||
return | |||||
self._has_setup = True | |||||
self.setup_config() | |||||
# 如果用户需要使用多机模式,那么一定进入到这里; | |||||
if self.is_pull_by_torch_run: | |||||
if self.outside_ddp: | |||||
self.world_size = dist.get_world_size() | |||||
self.global_rank = dist.get_rank() | |||||
else: | |||||
# dist.get_world_size() 只能在 dist.init_process_group 初始化之后进行调用; | |||||
self.world_size = int(os.environ.get("WORLD_SIZE")) | |||||
self.global_rank = int(os.environ.get("RANK")) | |||||
logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}") | |||||
if not dist.is_initialized(): | |||||
deepspeed.init_distributed("nccl", distributed_port=self.master_port) | |||||
os.environ["fastnlp_torch_launch_not_ddp"] = "yes" | |||||
# 进入到这里的情况时: | |||||
# dist.is_initialized 一定为 False; | |||||
# 一定是单机; | |||||
# self.parallel_device 一定是 List[torch.device]; | |||||
else: | |||||
if not dist.is_initialized(): | |||||
# 这里主要的问题在于要区分 rank0 和其它 rank 的情况; | |||||
self.world_size = len(self.parallel_device) | |||||
self.open_subprocess() | |||||
self.global_rank = self.local_rank # rank 一定是通过环境变量去获取的; | |||||
deepspeed.init_distributed("nccl", distributed_port=self.master_port) | |||||
# 用户在这个 trainer 前面又初始化了一个 trainer,并且使用的是 DeepSpeedDriver; | |||||
else: | |||||
# 如果 `dist.is_initialized() == True`,那么说明 DeepSpeedDriver 在之前已经初始化并且已经 setup 过一次,那么我们需要保证现在 | |||||
# 使用的(即之后的)DeepSpeedDriver 的设置和第一个 DeepSpeedDriver 是完全一样的; | |||||
pre_num_processes = int(os.environ[FASTNLP_DISTRIBUTED_CHECK]) | |||||
if pre_num_processes != len(self.parallel_device): | |||||
raise RuntimeError( | |||||
"Notice you are using `DeepSpeedDriver` after one instantiated `DeepSpeedDriver`, it is not" | |||||
"allowed that your second `DeepSpeedDriver` has a new setting of parameters " | |||||
"`num_nodes` and `num_processes`.") | |||||
self.world_size = dist.get_world_size() | |||||
self.global_rank = dist.get_rank() | |||||
if not self.outside_ddp: | |||||
torch.cuda.set_device(self.model_device) | |||||
# 不加 dist.broadcast_object_list 会发生设备在 4,5 但是模型会同步到 0,1 的情况 | |||||
# 原因未知 | |||||
dist.broadcast_object_list(["test"], 0, None) | |||||
self.configure_ddp() | |||||
self.barrier() | |||||
# 初始化 self._pids,从而使得每一个进程都能接受到 rank0 的 send 操作; | |||||
self._pids = [torch.tensor(0, dtype=torch.int).to(self.data_device) for _ in range(dist.get_world_size())] | |||||
dist.all_gather(self._pids, torch.tensor(os.getpid(), dtype=torch.int).to(self.data_device)) | |||||
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) if "LOCAL_WORLD_SIZE" in os.environ else None | |||||
if local_world_size is None: | |||||
local_world_size = torch.tensor(int(os.environ.get("LOCAL_RANK")), dtype=torch.int).to(self.data_device) | |||||
dist.all_reduce(local_world_size, op=dist.ReduceOp.MAX) | |||||
local_world_size = local_world_size.tolist() + 1 | |||||
node_rank = self.global_rank // local_world_size | |||||
self._pids = self._pids[node_rank * local_world_size: (node_rank + 1) * local_world_size] | |||||
self._pids = self.tensor_to_numeric(self._pids) | |||||
def configure_ddp(self): | |||||
# 设置 deepspeed | |||||
if not isinstance(self.model, DeepSpeedEngine): | |||||
model=_DeepSpeedWrappingModel(self.model, self.fp16) | |||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters()) | |||||
self.model, ds_optimizer, _, _ = deepspeed.initialize( | |||||
args=argparse.Namespace(device_rank=self.model_device.index), | |||||
model=model, | |||||
optimizer=self.optimizers[0], | |||||
model_parameters=model_parameters, | |||||
config=self.config, | |||||
dist_init_required=False | |||||
) | |||||
self._optimizers = [ds_optimizer] | |||||
if self.config.get("activation_checkpointing"): | |||||
checkpoint_config = self.config["activation_checkpointing"] | |||||
deepspeed.checkpointing.configure( | |||||
mpu_=None, | |||||
partition_activations=checkpoint_config.get("partition_activations"), | |||||
contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"), | |||||
checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"), | |||||
profile=checkpoint_config.get("profile"), | |||||
) | |||||
self._has_ddpwrapped = True | |||||
def setup_config(self): | |||||
self.config = self._ds_kwargs.get("config") | |||||
if self.config is not None: | |||||
logger.warning("Notice that you have defined a configuration for deepspeed and parameters like" | |||||
"`optimizers`, `strategy` and `fp16` may not take effects.") | |||||
return | |||||
if self.strategy == "deepspeed": | |||||
self.config = _create_default_config(stage=2) | |||||
elif self.strategy == "deepspeed_stage_1": | |||||
self.config = _create_default_config(stage=1) | |||||
elif self.strategy == "deepspeed_stage_2": | |||||
self.config = _create_default_config(stage=2) | |||||
elif self.strategy == "deepspeed_stage_2_offload": | |||||
self.config = _create_default_config(stage=2, offload_optimizer=True) | |||||
elif self.strategy == "deepspeed_stage_3": | |||||
self.config = _create_default_config(stage=3) | |||||
elif self.strategy == "deepspeed_stage_3_offload": | |||||
self.config = _create_default_config( | |||||
stage=3, | |||||
offload_optimizer=True, | |||||
offload_parameters=True, | |||||
) | |||||
elif self.strategy == "deepspeed_stage_3_offload_nvme": | |||||
self.config = _create_default_config( | |||||
stage=3, | |||||
offload_optimizer=True, | |||||
offload_parameters=True, | |||||
remote_device="nvme", | |||||
offload_params_device="nvme", | |||||
offload_optimizer_device="nvme", | |||||
) | |||||
else: | |||||
raise ValueError(f"Unknown deepspeed strategy {self.strategy}.") | |||||
# 设置成 max_int 防止 deepspeed 的输出干扰 fastnlp 的输出 | |||||
self.config.setdefault("steps_per_print", 2147483647) | |||||
self.config["gradient_accumulation_steps"] = self.accumulation_steps | |||||
self.config.setdefault("train_micro_batch_size_per_gpu", self.train_micro_batch_size) | |||||
if self.fp16: | |||||
if "fp16" not in self.config: | |||||
# FP16 is a DeepSpeed standalone AMP implementation | |||||
logger.debug("Enabling DeepSpeed FP16.") | |||||
# TODO 这部分是否可以像 pytorch-lightning 那样给用户定制 | |||||
self.config["fp16"] = { | |||||
"enabled": True, | |||||
"loss_scale": 0, | |||||
"initial_scale_power": True, | |||||
"loss_scale_window": 1000, | |||||
"hysteresis": 2, | |||||
"min_loss_scale": 1, | |||||
} | |||||
elif "amp" not in self.config: | |||||
logger.debug("Enabling DeepSpeed APEX Implementation.") | |||||
self.config["amp"] = {"enabled": True, "opt_level": "O1"} | |||||
def zero_grad(self): | |||||
# DeepSpeedEngine.step 包含了 zero_grad 功能 | |||||
pass | |||||
def backward(self, loss): | |||||
self.model.backward(loss) | |||||
def step(self): | |||||
self.model.step() | |||||
def get_model_no_sync_context(self): | |||||
r""" | |||||
:return: 返回一个 ``context`` 上下文环境,用于关闭各个进程之间的同步;在 ``deepspeed`` 中,返回一个空的上下文 | |||||
""" | |||||
# 注意此时的 model 是 "DistributedDataParallel" 对象; | |||||
return nullcontext | |||||
def save_model(self, filepath: Union[str, Path], only_state_dict: bool = False, **kwargs): | |||||
""" | |||||
保存当前 driver 的模型到 folder 下。 | |||||
:param filepath: 保存到哪个文件夹; | |||||
:param only_state_dict: 是否只保存权重;在 ``DeepSpeedDriver`` 中该参数无效; | |||||
:return: | |||||
""" | |||||
# deepspeed engine 要求在每个 rank 都调用 save_checkpoint,故去掉了 rank_zero_call 装饰器 | |||||
if self.stage_3: | |||||
logger.rank_zero_warning( | |||||
"When saving the DeepSpeed Stage 3 checkpoint, " | |||||
"each worker will save a shard of the checkpoint within a directory. " | |||||
# TODO check一下 | |||||
# "If a single file is required after training, " | |||||
# "see https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#" | |||||
# "deepspeed-zero-stage-3-single-file for instructions." | |||||
) | |||||
if not only_state_dict: | |||||
logger.rank_zero_warning("Only saving state dict is not allowed for `DeepSpeedDriver`. We will save its " | |||||
"checkpoint for you instead.") | |||||
self.model.save_checkpoint(filepath, **kwargs) | |||||
def load_model(self, filepath: Union[Path, str], only_state_dict: bool = False, **kwargs): | |||||
""" | |||||
从 folder 中加载权重并赋值到当前 driver 的模型上。 | |||||
:param filepath: 加载权重或模型的路径 | |||||
:param load_state_dict: 保存的内容是否只是权重;在 ``DeepSpeedDriver`` 中该参数无效; | |||||
:param kwargs: | |||||
:return: | |||||
""" | |||||
if not only_state_dict: | |||||
logger.warning("Only loading state dict is not allowed for `DeepSpeedDriver`. We will load its " | |||||
"checkpoint for you instead.") | |||||
self.model.load_checkpoint(filepath, **kwargs) | |||||
def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||||
# deepspeed engine 要求在每个 rank 都调用 save_checkpoint,故去掉了 rank_zero_call 装饰器 | |||||
# 1. 保存 sampler 的状态 | |||||
num_consumed_batches = states.pop('num_consumed_batches') | |||||
states['sampler_states'] = self.get_sampler_state(dataloader, num_consumed_batches) | |||||
# 2. 保存模型的状态; | |||||
if not should_save_model: | |||||
logger.rank_zero_warning("Saving checkpoint without model is not allowed for `DeepSpeedDriver`, " | |||||
"so we will still save the model for you.") | |||||
self.model.save_checkpoint(Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME), | |||||
client_state=states) | |||||
def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||||
# 1. 加载模型状态; | |||||
if not should_load_model: | |||||
logger.rank_zero_warning("Loading checkpoint without model is not allowed for `DeepSpeedDriver`, " | |||||
"so we will still load the model for you.") | |||||
load_path, states = self.model.load_checkpoint(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) | |||||
if load_path is None: | |||||
raise RuntimeError(f"Failed to load checkpoint from path: {str(folder)}") | |||||
# 2.恢复 sampler 的状态 | |||||
sampler_states = states.pop('sampler_states') | |||||
states_ret = self.load_sampler_state(dataloader, sampler_states) | |||||
states.update(states_ret) | |||||
return states | |||||
@property | |||||
def stage_3(self) -> bool: | |||||
return self.config.get("zero_optimization") and self.config.get("zero_optimization").get("stage") == 3 |
@@ -35,11 +35,12 @@ class FairScaleDriver(TorchDDPDriver): | |||||
parallel_device: Union[List["torch.device"], "torch.device"], | parallel_device: Union[List["torch.device"], "torch.device"], | ||||
is_pull_by_torch_run = False, | is_pull_by_torch_run = False, | ||||
fp16: bool = False, | fp16: bool = False, | ||||
fairscale_kwargs: Dict = None, | |||||
**kwargs | **kwargs | ||||
): | ): | ||||
assert _NEED_IMPORT_FAIRSCALE, "fairscale is not imported." | assert _NEED_IMPORT_FAIRSCALE, "fairscale is not imported." | ||||
assert not dist.is_initialized(), "FairScaleDriver does not support initialize distributed by user." | assert not dist.is_initialized(), "FairScaleDriver does not support initialize distributed by user." | ||||
self._fairscale_kwargs = kwargs.get('fairscale_kwargs', {}) | |||||
self._fairscale_kwargs = fairscale_kwargs | |||||
self.fs_type = self._fairscale_kwargs.get('fs_type', 'sdp') # ddp, sdp, fsdp | self.fs_type = self._fairscale_kwargs.get('fs_type', 'sdp') # ddp, sdp, fsdp | ||||
if self.fs_type == 'fsdp': | if self.fs_type == 'fsdp': | ||||
self._fairscale_kwargs['set_grad_to_none'] = self._fairscale_kwargs.get('set_grad_to_none', True) | self._fairscale_kwargs['set_grad_to_none'] = self._fairscale_kwargs.get('set_grad_to_none', True) | ||||
@@ -8,6 +8,7 @@ from .torch_driver import TorchDriver | |||||
from .single_device import TorchSingleDriver | from .single_device import TorchSingleDriver | ||||
from .ddp import TorchDDPDriver | from .ddp import TorchDDPDriver | ||||
from .fairscale import FairScaleDriver | from .fairscale import FairScaleDriver | ||||
from .deepspeed import DeepSpeedDriver | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs import FASTNLP_BACKEND_LAUNCH | from fastNLP.envs import FASTNLP_BACKEND_LAUNCH | ||||
from pkg_resources import parse_version | from pkg_resources import parse_version | ||||
@@ -20,7 +21,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi | |||||
r""" | r""" | ||||
用来根据参数 ``driver` 和 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去; | 用来根据参数 ``driver` 和 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去; | ||||
:param driver: 该参数的值应为以下之一:``["torch", "fairscale"]``; | |||||
:param driver: 该参数的值应为以下之一:``["torch", "fairscale", "deepspeed"]``; | |||||
:param device: 该参数的格式与 ``Trainer`` 对参数 ``device`` 的要求一致; | :param device: 该参数的格式与 ``Trainer`` 对参数 ``device`` 的要求一致; | ||||
:param model: 训练或者评测的具体的模型; | :param model: 训练或者评测的具体的模型; | ||||
@@ -37,11 +38,14 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi | |||||
if driver == 'fairscale': | if driver == 'fairscale': | ||||
return FairScaleDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), | return FairScaleDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), | ||||
is_pull_by_torch_run=True, **kwargs) | is_pull_by_torch_run=True, **kwargs) | ||||
elif driver == 'deepspeed': | |||||
return DeepSpeedDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), | |||||
is_pull_by_torch_run=True, **kwargs) | |||||
else: | else: | ||||
return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), | return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), | ||||
is_pull_by_torch_run=True, **kwargs) | is_pull_by_torch_run=True, **kwargs) | ||||
if driver not in {"torch", "fairscale"}: | |||||
if driver not in {"torch", "fairscale", "deepspeed"}: | |||||
raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale'].") | raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale'].") | ||||
_could_use_device_num = torch.cuda.device_count() | _could_use_device_num = torch.cuda.device_count() | ||||
@@ -83,4 +87,12 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi | |||||
logger.warning_once("Notice you are using `fairscale`, but the `device` is only one gpu.") | logger.warning_once("Notice you are using `fairscale`, but the `device` is only one gpu.") | ||||
return FairScaleDriver(model, [device], **kwargs) | return FairScaleDriver(model, [device], **kwargs) | ||||
else: | else: | ||||
return FairScaleDriver(model, device, **kwargs) | |||||
return FairScaleDriver(model, device, **kwargs) | |||||
elif driver == "deepspeed": | |||||
if not isinstance(device, List): | |||||
if device.type == 'cpu': | |||||
raise ValueError("You are using `deepspeed` driver, but your chosen `device` is 'cpu'.") | |||||
logger.warning_once("Notice you are using `deepspeed`, but the `device` is only one gpu.") | |||||
return DeepSpeedDriver(model, [device], **kwargs) | |||||
else: | |||||
return DeepSpeedDriver(model, device, **kwargs) |
@@ -8,6 +8,7 @@ if _NEED_IMPORT_TORCH: | |||||
from torch.nn.parallel import DistributedDataParallel | from torch.nn.parallel import DistributedDataParallel | ||||
from torch.utils.data import RandomSampler as TorchRandomSampler | from torch.utils.data import RandomSampler as TorchRandomSampler | ||||
from torch.utils.data import SequentialSampler as TorchSequentialSampler | from torch.utils.data import SequentialSampler as TorchSequentialSampler | ||||
from torch.utils.data import BatchSampler as TorchBatchSampler | |||||
__all__ = [ | __all__ = [ | ||||
'TorchSingleDriver' | 'TorchSingleDriver' | ||||
@@ -34,9 +35,13 @@ class TorchSingleDriver(TorchDriver): | |||||
:param model: 传入给 ``Trainer`` 的 ``model`` 参数; | :param model: 传入给 ``Trainer`` 的 ``model`` 参数; | ||||
:param device: torch.device,当前进程所使用的设备; | :param device: torch.device,当前进程所使用的设备; | ||||
:param fp16: 是否开启 fp16; | :param fp16: 是否开启 fp16; | ||||
:param torch_kwargs: | |||||
* *set_grad_to_none* -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | |||||
* *non_blocking* -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; | |||||
* *gradscaler_kwargs* -- 用于 fp16=True 时,提供给 ``torch.amp.cuda.GradScaler`` 的参数; | |||||
""" | """ | ||||
def __init__(self, model, device: "torch.device", fp16: bool = False, **kwargs): | |||||
def __init__(self, model, device: "torch.device", fp16: bool = False, torch_kwargs: Dict = None, **kwargs): | |||||
if isinstance(model, DistributedDataParallel): | if isinstance(model, DistributedDataParallel): | ||||
raise ValueError("`DistributedDataParallel` is not supported in `TorchSingleDriver`") | raise ValueError("`DistributedDataParallel` is not supported in `TorchSingleDriver`") | ||||
@@ -46,7 +51,7 @@ class TorchSingleDriver(TorchDriver): | |||||
logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in system environment variable, and we are gonna to" | logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in system environment variable, and we are gonna to" | ||||
"use `cpu` instead of `gpu` device.") | "use `cpu` instead of `gpu` device.") | ||||
super(TorchSingleDriver, self).__init__(model, fp16=fp16, **kwargs) | |||||
super(TorchSingleDriver, self).__init__(model, fp16=fp16, torch_kwargs=torch_kwargs, **kwargs) | |||||
if device is None: | if device is None: | ||||
logger.debug("device is not set, fastNLP will try to automatically get it.") | logger.debug("device is not set, fastNLP will try to automatically get it.") | ||||
@@ -123,19 +128,20 @@ class TorchSingleDriver(TorchDriver): | |||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
if reproducible: | if reproducible: | ||||
if isinstance(args.sampler, TorchRandomSampler): | |||||
if getattr(args.sampler, '_num_samples', None) is None \ | |||||
and getattr(args.sampler, 'replacements', False) is False \ | |||||
and getattr(args.sampler, 'generator', None) is None: | |||||
# 如果本来就是随机的,并且没有定制,直接替换掉吧。 | |||||
sampler = RandomSampler(args.sampler.data_source, shuffle=True) | |||||
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") | |||||
if type(args.batch_sampler) is TorchBatchSampler: | |||||
if type(args.sampler) is TorchRandomSampler: | |||||
if getattr(args.sampler, '_num_samples', None) is None \ | |||||
and getattr(args.sampler, 'replacements', False) is False \ | |||||
and getattr(args.sampler, 'generator', None) is None: | |||||
# 如果本来就是随机的,并且没有定制,直接替换掉吧。 | |||||
sampler = RandomSampler(args.sampler.data_source, shuffle=True) | |||||
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") | |||||
return replace_sampler(dataloader, sampler) | |||||
elif type(args.sampler) is TorchSequentialSampler: | |||||
# 需要替换为不要 shuffle 的。 | |||||
sampler = RandomSampler(args.sampler.data_source, shuffle=False) | |||||
logger.debug("Replace torch SequentialSampler into fastNLP RandomSampler.") | |||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
elif isinstance(args.sampler, TorchSequentialSampler): | |||||
# 需要替换为不要 shuffle 的。 | |||||
sampler = RandomSampler(args.sampler.data_source, shuffle=False) | |||||
logger.debug("Replace torch SequentialSampler into fastNLP RandomSampler.") | |||||
return replace_sampler(dataloader, sampler) | |||||
batch_sampler = ReproduceBatchSampler( | batch_sampler = ReproduceBatchSampler( | ||||
batch_sampler=args.batch_sampler, | batch_sampler=args.batch_sampler, | ||||
batch_size=args.batch_size, | batch_size=args.batch_size, | ||||
@@ -31,6 +31,7 @@ from fastNLP.envs import rank_zero_call | |||||
from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler | from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler | ||||
from fastNLP.core.dataloaders import OverfitDataLoader | |||||
class TorchDriver(Driver): | class TorchDriver(Driver): | ||||
@@ -46,12 +47,15 @@ class TorchDriver(Driver): | |||||
您可以在使用 ``TorchSingleDriver`` 和 ``TorchDDPDriver`` 时使用 ``TorchDriver`` 提供的接口; | 您可以在使用 ``TorchSingleDriver`` 和 ``TorchDDPDriver`` 时使用 ``TorchDriver`` 提供的接口; | ||||
:param model: 训练时使用的 **pytorch** 模型; | |||||
:param fp16: 是否开启混合精度训练; | |||||
:param torch_kwargs: | |||||
""" | """ | ||||
def __init__(self, model, fp16: Optional[bool] = False, **kwargs): | |||||
def __init__(self, model, fp16: Optional[bool] = False, torch_kwargs: Dict = None, **kwargs): | |||||
super(TorchDriver, self).__init__(model) | super(TorchDriver, self).__init__(model) | ||||
""" 进行 fp16 的设置 """ | """ 进行 fp16 的设置 """ | ||||
self._torch_kwargs = kwargs.get("torch_kwargs", {}) | |||||
self._torch_kwargs = torch_kwargs if torch_kwargs is not None else {} | |||||
# 因为 ddp 和 single_device 的混合精度训练的设置是一样的,因此可以统一抽象到这里; | # 因为 ddp 和 single_device 的混合精度训练的设置是一样的,因此可以统一抽象到这里; | ||||
self.fp16 = fp16 | self.fp16 = fp16 | ||||
@@ -92,7 +96,7 @@ class TorchDriver(Driver): | |||||
self.grad_scaler.update() | self.grad_scaler.update() | ||||
def check_dataloader_legality(self, dataloader): | def check_dataloader_legality(self, dataloader): | ||||
if not isinstance(dataloader, DataLoader): | |||||
if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader): | |||||
raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") | raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") | ||||
if len(dataloader) == 0: | if len(dataloader) == 0: | ||||
logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it " | logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it " | ||||
@@ -189,7 +193,30 @@ class TorchDriver(Driver): | |||||
# 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变 | # 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变 | ||||
# trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; | # trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; | ||||
# 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; | |||||
# 1. sampler 的状态; | |||||
num_consumed_batches = states.pop('num_consumed_batches') | |||||
states['sampler_states'] = self.get_sampler_state(dataloader, num_consumed_batches) | |||||
# 2. 保存模型的状态; | |||||
if should_save_model: | |||||
if not os.path.exists(folder): | |||||
os.mkdir(folder) | |||||
model_path = folder.joinpath(FASTNLP_MODEL_FILENAME) | |||||
self.save_model(model_path, only_state_dict=only_state_dict) | |||||
# 3. 保存 optimizers 的状态; | |||||
states["optimizers_state_dict"] = self.get_optimizer_state() | |||||
logger.debug("Save optimizer state dict.") | |||||
# 4. 保存fp16的状态 | |||||
if not isinstance(self.grad_scaler, DummyGradScaler): | |||||
grad_scaler_state_dict = self.grad_scaler.state_dict() | |||||
states['grad_scaler_state_dict'] = grad_scaler_state_dict | |||||
torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | |||||
def get_sampler_state(self, dataloader, num_consumed_batches): | |||||
# 因为我们支持 resume training,即精确恢复到具体的一个 batch; | |||||
# 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `set_` 中将 dataloader 的 | # 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `set_` 中将 dataloader 的 | ||||
# sampler 替换为 `ReproducibleSampler`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`; | # sampler 替换为 `ReproducibleSampler`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`; | ||||
dataloader_args = self.get_dataloader_args(dataloader) | dataloader_args = self.get_dataloader_args(dataloader) | ||||
@@ -199,53 +226,58 @@ class TorchDriver(Driver): | |||||
sampler = dataloader_args.sampler | sampler = dataloader_args.sampler | ||||
else: | else: | ||||
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | ||||
num_consumed_batches = states.pop('num_consumed_batches') | |||||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | ||||
sampler_states = sampler.state_dict() | sampler_states = sampler.state_dict() | ||||
# 需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||||
# 会造成多余实际消耗的问题。因为 | |||||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||||
if num_consumed_samples_array is not None: | |||||
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||||
if dataloader_args.batch_size is not None: | |||||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||||
else: # 有可能 batch_size 为 None,就只有损失精度了 | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
"it may cause missing some samples when reload.") | |||||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||||
if dataloader_args.batch_size is not None: | |||||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
* num_consumed_batches | |||||
else: | else: | ||||
if dataloader_args.batch_size is not None: | |||||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
* num_consumed_batches | |||||
else: | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
"it may cause missing some samples when reload.") | |||||
states['sampler_states'] = sampler_states | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on sampler's " | |||||
"`num_consumed_samples`, it may cause missing some samples when reload.") | |||||
else: | else: | ||||
raise RuntimeError('The sampler has no `state_dict()` method, fastNLP cannot save the training ' | raise RuntimeError('The sampler has no `state_dict()` method, fastNLP cannot save the training ' | ||||
'state.') | 'state.') | ||||
# 2. 保存模型的状态; | |||||
if should_save_model: | |||||
if not os.path.exists(folder): | |||||
os.mkdir(folder) | |||||
model_path = folder.joinpath(FASTNLP_MODEL_FILENAME) | |||||
self.save_model(model_path, only_state_dict=only_state_dict) | |||||
return sampler_states | |||||
# 3. 保存 optimizers 的状态; | |||||
optimizers_state_dict = self.get_optimizer_state() | |||||
def load_sampler_state(self, dataloader, sampler_states): | |||||
states = {} | |||||
dataloader_args = self.get_dataloader_args(dataloader) | |||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||||
sampler = dataloader_args.batch_sampler | |||||
elif isinstance(dataloader_args.sampler, ReproducibleSampler): | |||||
sampler = dataloader_args.sampler | |||||
elif isinstance(dataloader_args.sampler, TorchRandomSampler): | |||||
sampler = RandomSampler(dataloader_args.sampler.data_source) | |||||
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") | |||||
elif self.is_distributed(): | |||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our" | |||||
"`ReproducibleSampler`.") | |||||
else: | |||||
sampler = ReproduceBatchSampler( | |||||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | |||||
batch_size=dataloader_args.batch_size, | |||||
drop_last=dataloader_args.drop_last | |||||
) | |||||
sampler.load_state_dict(sampler_states) | |||||
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | |||||
# 4. 保存fp16的状态 | |||||
if not isinstance(self.grad_scaler, DummyGradScaler): | |||||
grad_scaler_state_dict = self.grad_scaler.state_dict() | |||||
states['grad_scaler_state_dict'] = grad_scaler_state_dict | |||||
# 修改 trainer_state.batch_idx_in_epoch | |||||
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | |||||
if not isinstance(sampler, ReproducibleBatchSampler): | |||||
if dataloader_args.drop_last: | |||||
batch_idx_in_epoch = len( | |||||
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size | |||||
else: | |||||
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ | |||||
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size | |||||
# sampler 是 batch_sampler; | |||||
else: | |||||
batch_idx_in_epoch = sampler.batch_idx_in_epoch | |||||
logger.debug("Save optimizer state dict") | |||||
states["optimizers_state_dict"] = optimizers_state_dict | |||||
torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | |||||
states["batch_idx_in_epoch"] = batch_idx_in_epoch | |||||
return states | |||||
def get_optimizer_state(self): | def get_optimizer_state(self): | ||||
optimizers_state_dict = {} | optimizers_state_dict = {} | ||||
@@ -275,7 +307,7 @@ class TorchDriver(Driver): | |||||
if should_load_model: | if should_load_model: | ||||
self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict) | self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict) | ||||
# 3. 加载fp16的状态 | |||||
# 3. 加载 fp16 的状态 | |||||
if "grad_scaler_state_dict" in states: | if "grad_scaler_state_dict" in states: | ||||
grad_scaler_state_dict = states.pop("grad_scaler_state_dict") | grad_scaler_state_dict = states.pop("grad_scaler_state_dict") | ||||
if not isinstance(self.grad_scaler, DummyGradScaler): | if not isinstance(self.grad_scaler, DummyGradScaler): | ||||
@@ -286,40 +318,9 @@ class TorchDriver(Driver): | |||||
f"the training process may be unstable.") | f"the training process may be unstable.") | ||||
# 4. 恢复 sampler 的状态; | # 4. 恢复 sampler 的状态; | ||||
dataloader_args = self.get_dataloader_args(dataloader) | |||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||||
sampler = dataloader_args.batch_sampler | |||||
elif isinstance(dataloader_args.sampler, ReproducibleSampler): | |||||
sampler = dataloader_args.sampler | |||||
elif isinstance(dataloader_args.sampler, TorchRandomSampler): | |||||
sampler = RandomSampler(dataloader_args.sampler.data_source) | |||||
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") | |||||
elif self.is_distributed(): | |||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our" | |||||
"`ReproducibleSampler`.") | |||||
else: | |||||
sampler = ReproduceBatchSampler( | |||||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | |||||
batch_size=dataloader_args.batch_size, | |||||
drop_last=dataloader_args.drop_last | |||||
) | |||||
sampler.load_state_dict(states.pop('sampler_states')) | |||||
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | |||||
# 4. 修改 trainer_state.batch_idx_in_epoch | |||||
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | |||||
if not isinstance(sampler, ReproducibleBatchSampler): | |||||
if dataloader_args.drop_last: | |||||
batch_idx_in_epoch = len( | |||||
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size | |||||
else: | |||||
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ | |||||
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size | |||||
# sampler 是 batch_sampler; | |||||
else: | |||||
batch_idx_in_epoch = sampler.batch_idx_in_epoch | |||||
states["batch_idx_in_epoch"] = batch_idx_in_epoch | |||||
sampler_states = states.pop('sampler_states') | |||||
states_ret = self.load_sampler_state(dataloader, sampler_states) | |||||
states.update(states_ret) | |||||
return states | return states | ||||
@@ -1,6 +1,6 @@ | |||||
import os | import os | ||||
from typing import Any, Dict, Optional | |||||
from typing import Any, Dict, Optional, Union | |||||
from enum import IntEnum | from enum import IntEnum | ||||
import contextlib | import contextlib | ||||
import random | import random | ||||
@@ -14,16 +14,19 @@ from fastNLP.envs import ( | |||||
FASTNLP_BACKEND_LAUNCH, | FASTNLP_BACKEND_LAUNCH, | ||||
FASTNLP_GLOBAL_SEED, | FASTNLP_GLOBAL_SEED, | ||||
) | ) | ||||
from fastNLP.core.samplers import re_instantiate_sampler | |||||
from fastNLP.core.utils import auto_param_call | |||||
from fastNLP.core.samplers import re_instantiate_sampler, ReproducibleBatchSampler | |||||
from fastNLP.core.utils import auto_param_call, apply_to_collection | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
# import torch.nn as nn | # import torch.nn as nn | ||||
from torch.nn import Module | from torch.nn import Module | ||||
from torch.utils.data import DataLoader, BatchSampler | |||||
from torch.utils.data.sampler import Sampler | |||||
from torch.utils.data import DataLoader | |||||
from torch.utils.data import RandomSampler as TorchRandomSampler | |||||
from torch.utils.data import SequentialSampler as TorchSequentialSampler | |||||
from torch.utils.data import BatchSampler as TorchBatchSampler | |||||
else: | else: | ||||
from fastNLP.core.utils.dummy_class import DummyClass as Module | from fastNLP.core.utils.dummy_class import DummyClass as Module | ||||
@@ -104,6 +107,29 @@ class _DDPWrappingModel(Module): | |||||
else: | else: | ||||
return fn(batch) | return fn(batch) | ||||
class _DeepSpeedWrappingModel(_DDPWrappingModel): | |||||
""" | |||||
继承 ``_DDPWrappingModel``,区别在于进行 forward 之前先将 float 数据转换为 float16 | |||||
""" | |||||
def __init__(self, model: Module, fp16): | |||||
super(_DeepSpeedWrappingModel, self).__init__(model) | |||||
self.fp16 = fp16 | |||||
def forward(self, batch, **kwargs): | |||||
if self.fp16: | |||||
batch = self._move_float_tensors_to_half(batch) | |||||
return super().forward(batch, **kwargs) | |||||
@staticmethod | |||||
def batch_to(data): | |||||
return data.half() | |||||
def _move_float_tensors_to_half(self, batch: Any): | |||||
batch = apply_to_collection(batch, (torch.FloatTensor, torch.cuda.FloatTensor), function=self.batch_to) | |||||
return batch | |||||
class DummyGradScaler: | class DummyGradScaler: | ||||
""" | """ | ||||
@@ -178,28 +204,33 @@ def replace_sampler(dataloader: "DataLoader", sampler): | |||||
instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')} | instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')} | ||||
# 'multiprocessing_context' 是 user-defined function; | # 'multiprocessing_context' 是 user-defined function; | ||||
instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context | |||||
if getattr(dataloader, 'multiprocessing_context', None) is not None: | |||||
instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context | |||||
# 拿到 dataloader '__init__' 函数的默认函数签名; | # 拿到 dataloader '__init__' 函数的默认函数签名; | ||||
init_params = dict(inspect.signature(dataloader.__init__).parameters) | init_params = dict(inspect.signature(dataloader.__init__).parameters) | ||||
# 这里为什么要单独弄的原因在于,用户在定制自己的 dataloader 的同时可能为了方便只设定一些参数,而后面直接使用 **kwargs 的方式,这时如果 | |||||
# 其在初始化自己的 dataloader 实例的时候加入了一些其它的新的参数(首先这一步是必要的,因为我们只能通过这样加 sampler;另一方面,用户 | |||||
# 可能确实通过 **kwargs 加入了一些新的参数),如果假设用户是这样使用的: "super().__init__(**kwargs)",那么我们就只能去 DataLoader | |||||
# 中寻找; | |||||
# 防止用户的 DataLoader 是继承了 pytorch 的 DataLoader,然后还是使用了 **kwargs 的方式对父类传参数 | |||||
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | ||||
if has_variadic_kwargs: | |||||
init_params.update(dict(inspect.signature(DataLoader.__init__).parameters)) | |||||
del init_params["self"] | |||||
if has_variadic_kwargs and isinstance(dataloader, DataLoader): | |||||
# 防止用户写入了 super().__init__(**kwargs) | |||||
for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items(): | |||||
if key not in init_params and key != 'self': | |||||
init_params[key] = value | |||||
# 因为我们刚才可能用 DataLoader 的默认参数将用户定制的 dataloader 的参数覆盖掉了,因此需要重新弄一遍; | |||||
# 如果初始化dataloader所使用的参数不是默认值,那么我们需要将其记录下来用于重新初始化时设置; | |||||
non_default_params = {name for name, p in init_params.items() if | non_default_params = {name for name, p in init_params.items() if | ||||
name in instance_attrs and p.default != instance_attrs[name]} | name in instance_attrs and p.default != instance_attrs[name]} | ||||
# add `dataset` as it might have been replaced with `*args` | # add `dataset` as it might have been replaced with `*args` | ||||
non_default_params.add("dataset") | non_default_params.add("dataset") | ||||
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} | reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} | ||||
reconstruct_args.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler)) | |||||
if isinstance(dataloader, DataLoader): | |||||
reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None}) | |||||
batch_sampler = getattr(dataloader, "batch_sampler") | |||||
if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler): | |||||
raise RuntimeError("It should not be running here, please report a bug to us.") | |||||
required_args = { | required_args = { | ||||
p.name | p.name | ||||
@@ -209,58 +240,32 @@ def replace_sampler(dataloader: "DataLoader", sampler): | |||||
and p.name not in reconstruct_args | and p.name not in reconstruct_args | ||||
} | } | ||||
# 这种错误针对的是 __init__ 中的参数没有用同样名字的 self 挂上; | |||||
# 在 attribute 中没有找到这些参数,导致了没有办法重新初始化 | |||||
if required_args: | if required_args: | ||||
required_args = sorted(required_args) | required_args = sorted(required_args) | ||||
dataloader_self_name = dataloader.__class__.__name__ | dataloader_self_name = dataloader.__class__.__name__ | ||||
raise Exception( | raise Exception( | ||||
f"Trying to inject `DistributedSampler` into the `{dataloader_self_name}` instance. " | |||||
"This would fail as some of the `__init__` arguments are not available as instance attributes. " | |||||
f"The missing attributes are {required_args}. " | |||||
f"HINT: If you wrote the `{dataloader_self_name}` class, define `self.missing_arg_name` or " | |||||
"manually add the `DistributedSampler` as: " | |||||
f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`." | |||||
f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. " | |||||
f"But they are not found in the attribute of `{dataloader_self_name}`, fastNLP cannot determine its " | |||||
f"value when try to reinitialize `{dataloader_self_name}`, please add `{required_args}` to be " | |||||
f"`{dataloader_self_name}`'s attribute." | |||||
) | ) | ||||
# 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs; | # 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs; | ||||
if not has_variadic_kwargs: | if not has_variadic_kwargs: | ||||
# the dataloader signature does not allow keyword arguments that need to be passed | # the dataloader signature does not allow keyword arguments that need to be passed | ||||
missing_kwargs = reconstruct_args.keys() - init_params.keys() | missing_kwargs = reconstruct_args.keys() - init_params.keys() | ||||
if missing_kwargs: | if missing_kwargs: | ||||
missing_kwargs = sorted(missing_kwargs) | missing_kwargs = sorted(missing_kwargs) | ||||
dataloader_self_name = dataloader.__class__.__name__ | dataloader_self_name = dataloader.__class__.__name__ | ||||
raise Exception( | raise Exception( | ||||
f"Trying to inject `DistributedSampler` into the `{dataloader_self_name}` instance. " | |||||
"This would fail as it doesn't expose all its attributes in the `__init__` signature. " | |||||
f"The missing arguments are {missing_kwargs}. " | |||||
f"HINT: If you wrote the `{dataloader_self_name}` class, add the `__init__` arguments or " | |||||
"manually add the `DistributedSampler` as: " | |||||
f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`." | |||||
f"The parameter:{missing_kwargs} needed to reinitialize `{dataloader_self_name}` is not found." | |||||
) | ) | ||||
return type(dataloader)(**reconstruct_args) | |||||
# 如果没有kwargs,则保证一下只传入需要的参数 | |||||
if not isinstance(dataloader, DataLoader): | |||||
reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params} | |||||
def _dataloader_init_kwargs_resolve_sampler( | |||||
dataloader: "DataLoader", sampler: Optional["Sampler"] | |||||
) -> Dict[str, Any]: | |||||
r""" | |||||
此函数用于处理与 DataLoader 关联的采样器、batch_sampler 参数重新实例化; | |||||
""" | |||||
batch_sampler = getattr(dataloader, "batch_sampler") | |||||
# checking the batch sampler type is different than PyTorch default. | |||||
if batch_sampler is not None and not isinstance(batch_sampler, BatchSampler): | |||||
batch_sampler = re_instantiate_sampler(batch_sampler) | |||||
return { | |||||
"sampler": None, | |||||
"shuffle": False, | |||||
"batch_sampler": batch_sampler, | |||||
"batch_size": 1, | |||||
"drop_last": False, | |||||
} | |||||
return {"sampler": sampler, "shuffle": False, "batch_sampler": None} | |||||
return type(dataloader)(**reconstruct_args) | |||||
def replace_batch_sampler(dataloader, new_batch_sampler): | def replace_batch_sampler(dataloader, new_batch_sampler): | ||||
@@ -273,6 +278,13 @@ def replace_batch_sampler(dataloader, new_batch_sampler): | |||||
params_keys.remove(k) | params_keys.remove(k) | ||||
params = {k: getattr(dataloader, k) for k in params_keys} | params = {k: getattr(dataloader, k) for k in params_keys} | ||||
params["batch_sampler"] = new_batch_sampler | params["batch_sampler"] = new_batch_sampler | ||||
if not isinstance(dataloader, DataLoader): | |||||
init_params = dict(inspect.signature(dataloader.__init__).parameters) | |||||
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | |||||
if not has_variadic_kwargs: | |||||
params = {key:value for key,value in params.items() if key in init_params} | |||||
return type(dataloader)(**params) | return type(dataloader)(**params) | ||||
@@ -295,5 +307,98 @@ def optimizer_state_to_device(state, device): | |||||
return new_state | return new_state | ||||
def _check_dataloader_args_for_distributed(args, controller='Trainer'): | |||||
if type(args.batch_sampler) is not TorchBatchSampler or (type(args.sampler) not in {TorchRandomSampler, | |||||
TorchSequentialSampler}): | |||||
mode = 'training' if controller == 'Trainer' else 'evaluation' | |||||
substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler' | |||||
raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause " | |||||
f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into " | |||||
f"``{substitution}``. The customized sampler should set for distributed running " | |||||
f"before initializing ``{controller}`` , and then set the " | |||||
f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.") | |||||
def _create_default_config( | |||||
zero_optimization: bool = True, | |||||
zero_allow_untested_optimizer: bool = True, | |||||
logging_batch_size_per_gpu: Union[str, int] = "auto", | |||||
partition_activations: bool = False, | |||||
cpu_checkpointing: bool = False, | |||||
contiguous_memory_optimization: bool = False, | |||||
synchronize_checkpoint_boundary: bool = False, | |||||
offload_optimizer: bool = False, | |||||
offload_parameters: bool = False, | |||||
offload_params_device: str = "cpu", | |||||
nvme_path: str = "/local_nvme", | |||||
params_buffer_count: int = 5, | |||||
params_buffer_size: int = 100_000_000, | |||||
max_in_cpu: int = 1_000_000_000, | |||||
offload_optimizer_device: str = "cpu", | |||||
optimizer_buffer_count: int = 4, | |||||
pin_memory: bool = False, | |||||
block_size: int = 1048576, | |||||
queue_depth: int = 8, | |||||
single_submit: bool = False, | |||||
overlap_events: bool = True, | |||||
thread_count: int = 1, | |||||
stage: int = 2, | |||||
contiguous_gradients: bool = True, | |||||
overlap_comm: bool = True, | |||||
allgather_partitions: bool = True, | |||||
reduce_scatter: bool = True, | |||||
allgather_bucket_size: int = 200_000_000, | |||||
reduce_bucket_size: int = 200_000_000, | |||||
sub_group_size: int = 1_000_000_000_000, | |||||
) -> Dict: | |||||
cfg = { | |||||
"activation_checkpointing": { | |||||
"partition_activations": partition_activations, | |||||
"cpu_checkpointing": cpu_checkpointing, | |||||
"contiguous_memory_optimization": contiguous_memory_optimization, | |||||
"synchronize_checkpoint_boundary": synchronize_checkpoint_boundary, | |||||
}, | |||||
"aio": { | |||||
"block_size": block_size, | |||||
"queue_depth": queue_depth, | |||||
"single_submit": single_submit, | |||||
"overlap_events": overlap_events, | |||||
"thread_count": thread_count, | |||||
}, | |||||
} | |||||
zero_kwargs = { | |||||
"stage": stage, | |||||
"contiguous_gradients": contiguous_gradients, | |||||
"overlap_comm": overlap_comm, | |||||
"allgather_partitions": allgather_partitions, | |||||
"reduce_scatter": reduce_scatter, | |||||
"allgather_bucket_size": allgather_bucket_size, | |||||
"reduce_bucket_size": reduce_bucket_size, | |||||
"sub_group_size": sub_group_size, | |||||
} | |||||
if zero_optimization: | |||||
zero_config = zero_kwargs | |||||
if offload_optimizer: | |||||
zero_config["offload_optimizer"] = { | |||||
"device": offload_optimizer_device, | |||||
"nvme_path": nvme_path, | |||||
"buffer_count": optimizer_buffer_count, | |||||
"pin_memory": pin_memory, | |||||
} | |||||
if offload_parameters: | |||||
zero_config["offload_param"] = { | |||||
"device": offload_params_device, | |||||
"nvme_path": nvme_path, | |||||
"buffer_count": params_buffer_count, | |||||
"buffer_size": params_buffer_size, | |||||
"max_in_cpu": max_in_cpu, | |||||
"pin_memory": pin_memory, | |||||
} | |||||
cfg = { | |||||
"zero_allow_untested_optimizer": zero_allow_untested_optimizer, | |||||
"zero_optimization": zero_config, | |||||
**cfg, | |||||
} | |||||
if logging_batch_size_per_gpu != "auto": | |||||
cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg} | |||||
return cfg |
@@ -69,7 +69,7 @@ class Accuracy(Metric): | |||||
elif pred.ndim == target.ndim + 1: | elif pred.ndim == target.ndim + 1: | ||||
pred = pred.argmax(axis=-1) | pred = pred.argmax(axis=-1) | ||||
if seq_len is None and target.ndim > 1: | if seq_len is None and target.ndim > 1: | ||||
logger.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | |||||
logger.warning("You are not passing `seq_len` to exclude pad when calculate accuracy.") | |||||
else: | else: | ||||
raise RuntimeError(f"when pred have size:{pred.shape}, target should have size: {pred.shape} or " | raise RuntimeError(f"when pred have size:{pred.shape}, target should have size: {pred.shape} or " | ||||
@@ -8,6 +8,7 @@ from .backend import Backend | |||||
from .torch_backend.backend import TorchBackend | from .torch_backend.backend import TorchBackend | ||||
from .paddle_backend.backend import PaddleBackend | from .paddle_backend.backend import PaddleBackend | ||||
from .jittor_backend.backend import JittorBackend | from .jittor_backend.backend import JittorBackend | ||||
from .oneflow_backend.backend import OneflowBackend | |||||
class AutoBackend(Backend): | class AutoBackend(Backend): | ||||
@@ -52,6 +53,8 @@ class AutoBackend(Backend): | |||||
self.__class__ = PaddleBackend | self.__class__ = PaddleBackend | ||||
elif backend == 'jittor': | elif backend == 'jittor': | ||||
self.__class__ = JittorBackend | self.__class__ = JittorBackend | ||||
elif backend == 'oneflow': | |||||
self.__class__ = OneflowBackend | |||||
elif backend is None: | elif backend is None: | ||||
# 不用做任何事情就可以初始化了 | # 不用做任何事情就可以初始化了 | ||||
pass | pass | ||||
@@ -0,0 +1,130 @@ | |||||
from typing import List | |||||
import numpy as np | |||||
from fastNLP.core.metrics.backend import Backend | |||||
from fastNLP.core.metrics.utils import AggregateMethodError | |||||
from fastNLP.core.utils import is_in_oneflow_dist | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
from fastNLP.core.drivers.oneflow_driver.dist_utils import fastnlp_oneflow_all_gather | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
import oneflow.comm as comm | |||||
__all__ = [] | |||||
class OneflowBackend(Backend): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self._specified = True | |||||
def aggregate(self, tensor, method: str): | |||||
""" | |||||
聚集结果,并根据 method 计算后,返回结果 | |||||
:param tensor: 需要聚合的张量 | |||||
:param method: 聚合的方法, 目前支持 ``['sum', 'mean', 'max', 'mix']``: | |||||
* method 为 ``'sum'`` 时, 会将多张卡上聚合结果在维度为 `0` 上 累加起来。 | |||||
* method 为 ``'mean'`` 时,会将多张卡上聚合结果在维度为 `0` 上取平均值。 | |||||
* method 为 ``'max'`` 时,会将多张卡上聚合结果在维度为 `0` 上取最大值。 | |||||
* method 为 ``'mix'`` 时,会将多张卡上聚合结果在维度为 `0` 上取最小值。 | |||||
""" | |||||
if isinstance(tensor, oneflow.Tensor): | |||||
# TODO 暂时没有找到 oneflow 中检测是否初始化了分布式环境的方法 | |||||
if is_in_oneflow_dist(): | |||||
if method is None: | |||||
raise AggregateMethodError(should_have_aggregate_method=True) | |||||
tensor = self.all_gather_object(tensor) | |||||
if isinstance(tensor[0], oneflow.Tensor): | |||||
tensor = oneflow.stack(tensor) | |||||
# 第一步, aggregate结果 | |||||
if method == 'sum': | |||||
tensor = oneflow.sum(tensor, dim=0) | |||||
elif method == 'mean': | |||||
tensor = oneflow.mean(tensor, dim=0) | |||||
elif method == 'max': | |||||
tensor, _ = oneflow.max(tensor, dim=0) | |||||
elif method == 'min': | |||||
tensor, _ = oneflow.min(tensor, dim=0) | |||||
else: | |||||
raise AggregateMethodError(should_have_aggregate_method=False) | |||||
return tensor | |||||
def create_tensor(self, value: float): | |||||
""" | |||||
创建 tensor,并且填入 value 作为值 | |||||
:param value: 创建张量的初始值 | |||||
""" | |||||
tensor = oneflow.ones(1).fill_(value) | |||||
return tensor | |||||
def fill_value(self, tensor, value: float): | |||||
""" | |||||
将 tensor 的值设置为 value | |||||
:param tensor: 传入的张量 | |||||
:param value: 需要 fill 的值。 | |||||
""" | |||||
tensor.fill_(value) | |||||
return tensor | |||||
def get_scalar(self, tensor) -> float: | |||||
""" | |||||
获取 tensor 的 scalar 值 | |||||
:param tensor: 传入的张量 | |||||
""" | |||||
return tensor.item() | |||||
def tensor2numpy(self, tensor) -> np.array: | |||||
""" | |||||
将 tensor 转为 numpy 值, 主要是在 metric 计算中使用 | |||||
:param tensor: 传入的张量 | |||||
""" | |||||
if isinstance(tensor, oneflow.Tensor): | |||||
return tensor.cpu().detach().numpy() | |||||
elif isinstance(tensor, np.ndarray): | |||||
return tensor | |||||
elif isinstance(tensor, (float, int)): | |||||
return tensor | |||||
else: | |||||
raise ValueError(f"tensor: {tensor} can not convert to ndarray!") | |||||
@staticmethod | |||||
def is_distributed() -> bool: | |||||
""" | |||||
判断是否为 ddp 状态 | |||||
:return: | |||||
""" | |||||
return is_in_oneflow_dist() | |||||
def move_tensor_to_device(self, tensor, device): | |||||
""" | |||||
将张量移到设备上 | |||||
:param tensor: 需要移动的张量 | |||||
:param device: 设备名, 一般为 "cpu", "cuda:0"等字符串 | |||||
""" | |||||
return tensor.to(device) | |||||
def all_gather_object(self, obj, group=None) -> List: | |||||
""" | |||||
给定 obj 将各个 rank 上的 obj 汇总到每个 obj 上。返回一个 list 对象,里面依次为各个 rank 对应的 obj 。 | |||||
:param obj: | |||||
:param group: | |||||
""" | |||||
if self.is_distributed(): | |||||
obj_list = fastnlp_oneflow_all_gather(obj) | |||||
return obj_list | |||||
return [obj] | |||||
@@ -156,7 +156,7 @@ class ClassifyFPreRecMetric(Metric): | |||||
elif pred.ndim == target.ndim + 1: | elif pred.ndim == target.ndim + 1: | ||||
pred = pred.argmax(axis=-1) | pred = pred.argmax(axis=-1) | ||||
if seq_len is None and target.ndim > 1: | if seq_len is None and target.ndim > 1: | ||||
logger.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | |||||
logger.warning("You are not passing `seq_len` to exclude pad when calculate accuracy.") | |||||
else: | else: | ||||
raise RuntimeError(f"when pred have " | raise RuntimeError(f"when pred have " | ||||
f"size:{pred.shape}, target should have size: {pred.shape} or " | f"size:{pred.shape}, target should have size: {pred.shape} or " | ||||
@@ -20,7 +20,7 @@ class Metric: | |||||
:param backend: 目前支持四种类型的 backend, ``[torch, paddle, jittor, auto]``。其中 ``auto`` 表示根据实际调用 | :param backend: 目前支持四种类型的 backend, ``[torch, paddle, jittor, auto]``。其中 ``auto`` 表示根据实际调用 | ||||
Metric.update() 函数时传入的参数决定具体的 ``backend`` ,大部分情况下直接使用 ``auto`` 即可。 | Metric.update() 函数时传入的参数决定具体的 ``backend`` ,大部分情况下直接使用 ``auto`` 即可。 | ||||
:param aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | :param aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | ||||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 :class:`~fastNLP.Evaluator` 中根据 sampler 是否使用分布式 | |||||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 :class:`~fastNLP.core.controllers.Evaluator` 中根据 sampler 是否使用分布式 | |||||
进行自动设置。 | 进行自动设置。 | ||||
""" | """ | ||||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None): | def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None): | ||||
@@ -98,7 +98,7 @@ class Metric: | |||||
return _wrap_get_metric | return _wrap_get_metric | ||||
def __setattr__(self, key, value): | def __setattr__(self, key, value): | ||||
if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True: | |||||
if getattr(self, '_cannot_change_element', False): | |||||
if key in self.elements and isinstance(value, (float, int, bool)): | if key in self.elements and isinstance(value, (float, int, bool)): | ||||
self.elements[key].fill_value(value) | self.elements[key].fill_value(value) | ||||
return | return | ||||
@@ -109,6 +109,14 @@ class Metric: | |||||
raise RuntimeError("Please use register_element() function to add Element.") | raise RuntimeError("Please use register_element() function to add Element.") | ||||
object.__setattr__(self, key, value) | object.__setattr__(self, key, value) | ||||
# 当调用 __getattribute__ 没有找到时才会触发这个, 保留这个的目的只是为了防止 ide 的 warning | |||||
def __getattr__(self, name: str) -> Element: | |||||
if 'elements' in self.__dict__: | |||||
elements = self.__dict__['elements'] | |||||
if name in elements: | |||||
return elements[name] | |||||
raise AttributeError("`{}` object has no attribute `{}`.".format(type(self).__name__, name)) | |||||
def _wrap_update(self, update): | def _wrap_update(self, update): | ||||
@functools.wraps(update) | @functools.wraps(update) | ||||
def _wrap_update(*args, **kwargs): | def _wrap_update(*args, **kwargs): | ||||
@@ -39,7 +39,7 @@ def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encod | |||||
f"encoding_type." | f"encoding_type." | ||||
tags = tags.replace(tag, '') # 删除该值 | tags = tags.replace(tag, '') # 删除该值 | ||||
if tags: # 如果不为空,说明出现了未使用的tag | if tags: # 如果不为空,说明出现了未使用的tag | ||||
logger.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your " | |||||
logger.warning(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your " | |||||
"encoding_type.") | "encoding_type.") | ||||
@@ -212,7 +212,7 @@ class SpanFPreRecMetric(Metric): | |||||
:param backend: 目前支持四种类型的 backend, ``[torch, paddle, jittor, auto]``。其中 ``auto`` 表示根据实际调用 | :param backend: 目前支持四种类型的 backend, ``[torch, paddle, jittor, auto]``。其中 ``auto`` 表示根据实际调用 | ||||
Metric.update() 函数时传入的参数决定具体的 ``backend`` ,大部分情况下直接使用 ``auto`` 即可。 | Metric.update() 函数时传入的参数决定具体的 ``backend`` ,大部分情况下直接使用 ``auto`` 即可。 | ||||
:param aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | :param aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | ||||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 :class:`~fastNLP.Evaluator` 中根据 sampler 是否使用分布式 | |||||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 :class:`~fastNLP.core.controllers.Evaluator` 中根据 sampler 是否使用分布式 | |||||
进行自动设置。 | 进行自动设置。 | ||||
""" | """ | ||||
def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, | def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, | ||||
@@ -13,7 +13,6 @@ from itertools import chain | |||||
import numpy as np | import numpy as np | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.envs.utils import get_global_seed | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from .utils import create_array | from .utils import create_array | ||||
from abc import abstractmethod | from abc import abstractmethod | ||||
@@ -171,7 +170,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
:param kwargs: fastNLP 保留使用 | :param kwargs: fastNLP 保留使用 | ||||
""" | """ | ||||
def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True, | def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True, | ||||
drop_last: bool = False, seed: int = None, **kwargs): | |||||
drop_last: bool = False, seed: int = 0, **kwargs): | |||||
super().__init__() | super().__init__() | ||||
self.dataset = dataset | self.dataset = dataset | ||||
@@ -179,7 +178,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
self.shuffle = shuffle | self.shuffle = shuffle | ||||
self.drop_last = drop_last | self.drop_last = drop_last | ||||
self.seed = get_global_seed() if seed is None else seed | |||||
self.seed = int(seed) | |||||
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 | self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 | ||||
@@ -398,7 +397,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
:param kwargs: fastNLP 保留使用 | :param kwargs: fastNLP 保留使用 | ||||
""" | """ | ||||
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | ||||
shuffle: bool = True, drop_last: bool = False, seed: int = None, **kwargs): | |||||
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | |||||
super().__init__() | super().__init__() | ||||
if isinstance(dataset, DataSet) and isinstance(length, str): | if isinstance(dataset, DataSet) and isinstance(length, str): | ||||
length = dataset.get_field(length).content | length = dataset.get_field(length).content | ||||
@@ -423,7 +422,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
self.num_batch_per_bucket = num_batch_per_bucket | self.num_batch_per_bucket = num_batch_per_bucket | ||||
self.shuffle = shuffle | self.shuffle = shuffle | ||||
self.drop_last = drop_last | self.drop_last = drop_last | ||||
self.seed = get_global_seed() if seed is None else seed | |||||
self.seed = int(seed) | |||||
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 | self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 | ||||
@@ -12,7 +12,6 @@ import numpy as np | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.envs.utils import get_global_seed | |||||
class ReproducibleSampler: | class ReproducibleSampler: | ||||
@@ -66,11 +65,11 @@ class RandomSampler(ReproducibleSampler): | |||||
:param seed: 随机数种子。 | :param seed: 随机数种子。 | ||||
:param kwargs: 用户不需要使用,fastNLP 内部使用 | :param kwargs: 用户不需要使用,fastNLP 内部使用 | ||||
""" | """ | ||||
def __init__(self, dataset, shuffle: bool = True, seed: int = None, **kwargs): | |||||
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | |||||
super(RandomSampler, self).__init__() | super(RandomSampler, self).__init__() | ||||
self.dataset = dataset | self.dataset = dataset | ||||
self.shuffle = shuffle | self.shuffle = shuffle | ||||
self.seed = get_global_seed() if seed is None else seed | |||||
self.seed = int(seed) | |||||
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 | self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 | ||||
@@ -7,7 +7,6 @@ __all__ = [ | |||||
from typing import List, Union | from typing import List, Union | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.envs.utils import get_global_seed | |||||
import numpy as np | import numpy as np | ||||
@@ -28,10 +27,10 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||||
:param seed: 设置的随机数种子 | :param seed: 设置的随机数种子 | ||||
:param kwargs: fastNLP 保留使用 | :param kwargs: fastNLP 保留使用 | ||||
""" | """ | ||||
def __init__(self, dataset, shuffle: bool = False, seed: int = None, **kwargs): | |||||
def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs): | |||||
self.dataset = dataset | self.dataset = dataset | ||||
self.shuffle = shuffle | self.shuffle = shuffle | ||||
self.seed = get_global_seed() if seed is None else seed | |||||
self.seed = int(seed) | |||||
# 多卡的相关的参数 | # 多卡的相关的参数 | ||||
self.num_replicas = kwargs.get('num_replicas', 1) | self.num_replicas = kwargs.get('num_replicas', 1) | ||||
@@ -1,5 +1,6 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'cache_results', | 'cache_results', | ||||
'is_jittor_module', | |||||
'is_jittor_dataset', | 'is_jittor_dataset', | ||||
'jittor_collate_wraps', | 'jittor_collate_wraps', | ||||
'paddle_to', | 'paddle_to', | ||||
@@ -9,8 +10,14 @@ __all__ = [ | |||||
'is_in_paddle_dist', | 'is_in_paddle_dist', | ||||
'is_in_fnlp_paddle_dist', | 'is_in_fnlp_paddle_dist', | ||||
'is_in_paddle_launch_dist', | 'is_in_paddle_launch_dist', | ||||
'is_paddle_module', | |||||
'f_rich_progress', | 'f_rich_progress', | ||||
'torch_move_data_to_device', | 'torch_move_data_to_device', | ||||
'is_torch_module', | |||||
'get_oneflow_device', | |||||
'oneflow_move_data_to_device', | |||||
'is_oneflow_module', | |||||
'is_in_oneflow_dist', | |||||
'get_fn_arg_names', | 'get_fn_arg_names', | ||||
'auto_param_call', | 'auto_param_call', | ||||
'check_user_specific_params', | 'check_user_specific_params', | ||||
@@ -28,11 +35,12 @@ __all__ = [ | |||||
] | ] | ||||
from .cache_results import cache_results | from .cache_results import cache_results | ||||
from .jittor_utils import is_jittor_dataset, jittor_collate_wraps | |||||
from .jittor_utils import is_jittor_dataset, jittor_collate_wraps, is_jittor_module | |||||
from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_device_id, get_paddle_gpu_str, is_in_paddle_dist, \ | from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_device_id, get_paddle_gpu_str, is_in_paddle_dist, \ | ||||
is_in_fnlp_paddle_dist, is_in_paddle_launch_dist | |||||
is_in_fnlp_paddle_dist, is_in_paddle_launch_dist, is_paddle_module | |||||
from .rich_progress import f_rich_progress | from .rich_progress import f_rich_progress | ||||
from .torch_utils import torch_move_data_to_device | |||||
from .torch_utils import torch_move_data_to_device, is_torch_module | |||||
from .oneflow_utils import oneflow_move_data_to_device, is_oneflow_module, is_in_oneflow_dist, get_oneflow_device | |||||
from .utils import * | from .utils import * | ||||
from .tqdm_progress import f_tqdm_progress | from .tqdm_progress import f_tqdm_progress | ||||
from .seq_len_to_mask import seq_len_to_mask | from .seq_len_to_mask import seq_len_to_mask | ||||
@@ -1,6 +1,7 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'is_jittor_module', | |||||
'is_jittor_dataset', | 'is_jittor_dataset', | ||||
'jittor_collate_wraps' | |||||
'jittor_collate_wraps', | |||||
] | ] | ||||
from collections.abc import Mapping, Callable | from collections.abc import Mapping, Callable | ||||
@@ -13,6 +14,17 @@ if _NEED_IMPORT_JITTOR: | |||||
from fastNLP.core.dataset import Instance | from fastNLP.core.dataset import Instance | ||||
def is_jittor_module(model) -> bool: | |||||
""" | |||||
判断传入的 ``model`` 是否是 :class:`jittor.Module` 类型 | |||||
:param model: 模型; | |||||
:return: 当前模型是否为 ``jittor`` 的模型; | |||||
""" | |||||
try: | |||||
return isinstance(model, jt.Module) | |||||
except BaseException: | |||||
return False | |||||
def is_jittor_dataset(dataset) -> bool: | def is_jittor_dataset(dataset) -> bool: | ||||
""" | """ | ||||
@@ -0,0 +1,69 @@ | |||||
import os | |||||
from typing import Any, Union, Optional | |||||
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
__all__ = [ | |||||
'get_oneflow_device' | |||||
'oneflow_move_data_to_device', | |||||
'is_oneflow_module', | |||||
'is_in_oneflow_dist', | |||||
] | |||||
from .utils import apply_to_collection | |||||
def get_oneflow_device(device): | |||||
""" | |||||
构造一个 :class:`oneflow.device` 实例并返回。 | |||||
:param device: 字符串或 gpu 编号 | |||||
:return: :class:`oneflow.device` | |||||
""" | |||||
if isinstance(device, oneflow.device): | |||||
return device | |||||
if isinstance(device, int): | |||||
return oneflow.device("cuda", device) | |||||
if isinstance(device, str): | |||||
return oneflow.device(device) | |||||
raise RuntimeError(f"Cannot get `oneflow.device` from {device}.") | |||||
def oneflow_move_data_to_device(batch: Any, device: Optional[Union[str, "oneflow.device"]] = None) -> Any: | |||||
r""" | |||||
在 **oneflow** 中将数据集合 ``batch`` 传输到给定设备。任何定义方法 ``to(device)`` 的对象都将被移动并且集合中的所有其他对象将保持不变; | |||||
:param batch: 需要迁移的数据; | |||||
:param device: 数据应当迁移到的设备;当该参数的值为 ``None`` 时则不执行任何操作; | |||||
:return: 迁移到新设备上的数据集合; | |||||
""" | |||||
if device is None: | |||||
return batch | |||||
def batch_to(data: Any) -> Any: | |||||
data_output = data.to(device) | |||||
if data_output is not None: | |||||
return data_output | |||||
# user wrongly implemented the `TransferableDataType` and forgot to return `self`. | |||||
return data | |||||
return apply_to_collection(batch, dtype=oneflow.Tensor, function=batch_to) | |||||
def is_oneflow_module(model) -> bool: | |||||
""" | |||||
判断传入的 ``model`` 是否是 :class:`oneflow.nn.Module` 类型 | |||||
:param model: 模型; | |||||
:return: 当前模型是否为 ``oneflow`` 的模型; | |||||
""" | |||||
try: | |||||
return isinstance(model, oneflow.nn.Module) | |||||
except BaseException: | |||||
return False | |||||
def is_in_oneflow_dist() -> bool: | |||||
""" | |||||
判断是否处于 **oneflow** 分布式的进程下。 | |||||
""" | |||||
return "GLOG_log_dir" in os.environ |
@@ -6,6 +6,7 @@ __all__ = [ | |||||
"is_in_paddle_dist", | "is_in_paddle_dist", | ||||
"is_in_fnlp_paddle_dist", | "is_in_fnlp_paddle_dist", | ||||
"is_in_paddle_launch_dist", | "is_in_paddle_launch_dist", | ||||
"is_paddle_module", | |||||
] | ] | ||||
import os | import os | ||||
@@ -174,4 +175,16 @@ def is_in_paddle_launch_dist() -> bool: | |||||
""" | """ | ||||
判断是否处于 ``python -m paddle.distributed.launch`` 方法启动的 **paddle** 分布式进程中 | 判断是否处于 ``python -m paddle.distributed.launch`` 方法启动的 **paddle** 分布式进程中 | ||||
""" | """ | ||||
return FASTNLP_BACKEND_LAUNCH in os.environ | |||||
return FASTNLP_BACKEND_LAUNCH in os.environ | |||||
def is_paddle_module(model) -> bool: | |||||
""" | |||||
判断传入的 ``model`` 是否是 :class:`paddle.nn.Layer` 类型 | |||||
:param model: 模型; | |||||
:return: 当前模型是否为 ``paddle`` 的模型; | |||||
""" | |||||
try: | |||||
return isinstance(model, paddle.nn.Layer) | |||||
except BaseException: | |||||
return False |
@@ -8,7 +8,8 @@ if _NEED_IMPORT_TORCH: | |||||
DEFAULT_TORCH_GROUP = torch.distributed.distributed_c10d.group.WORLD | DEFAULT_TORCH_GROUP = torch.distributed.distributed_c10d.group.WORLD | ||||
__all__ = [ | __all__ = [ | ||||
'torch_move_data_to_device' | |||||
'torch_move_data_to_device', | |||||
'is_torch_module', | |||||
] | ] | ||||
from .utils import apply_to_collection | from .utils import apply_to_collection | ||||
@@ -64,3 +65,15 @@ def torch_move_data_to_device(batch: Any, device: Optional[Union[str, "torch.dev | |||||
dtype = TorchTransferableDataType | dtype = TorchTransferableDataType | ||||
return apply_to_collection(batch, dtype=dtype, function=batch_to) | return apply_to_collection(batch, dtype=dtype, function=batch_to) | ||||
def is_torch_module(model) -> bool: | |||||
""" | |||||
判断传入的 ``model`` 是否是 :class:`torch.nn.Module` 类型 | |||||
:param model: 模型; | |||||
:return: 当前模型是否为 ``torch`` 的模型; | |||||
""" | |||||
try: | |||||
return isinstance(model, torch.nn.Module) | |||||
except BaseException: | |||||
return False |
@@ -554,7 +554,7 @@ def deprecated(help_message: Optional[str] = None): | |||||
def wrapper(*args, **kwargs): | def wrapper(*args, **kwargs): | ||||
func_hash = hash(deprecated_function) | func_hash = hash(deprecated_function) | ||||
if func_hash not in _emitted_deprecation_warnings: | if func_hash not in _emitted_deprecation_warnings: | ||||
logger.warn(warning_msg, category=FutureWarning, stacklevel=2) | |||||
logger.warning(warning_msg, category=FutureWarning, stacklevel=2) | |||||
_emitted_deprecation_warnings.add(func_hash) | _emitted_deprecation_warnings.add(func_hash) | ||||
return deprecated_function(*args, **kwargs) | return deprecated_function(*args, **kwargs) | ||||
@@ -630,7 +630,7 @@ def is_notebook(): | |||||
def flat_nest_dict(d:Dict, separator:str='#', compress_none_key:bool=True, top_down:bool=False) -> Dict: | def flat_nest_dict(d:Dict, separator:str='#', compress_none_key:bool=True, top_down:bool=False) -> Dict: | ||||
""" | """ | ||||
讲一个 nested 的 dict 转成 flat 的 dict,例如 | |||||
将一个 nested 的 dict 转成 flat 的 dict,例如 | |||||
ex:: | ex:: | ||||
d = {'test': {'f1': {'f': 0.2, 'rec': 0.1}}} -> {'f#f1#test':0.2, 'rec#f1#test':0.1} | d = {'test': {'f1': {'f': 0.2, 'rec': 0.1}}} -> {'f#f1#test':0.2, 'rec#f1#test':0.1} | ||||
@@ -286,7 +286,7 @@ class StaticEmbedding(TokenEmbedding): | |||||
if word in vocab: | if word in vocab: | ||||
index = vocab.to_index(word) | index = vocab.to_index(word) | ||||
if index in matrix: | if index in matrix: | ||||
logger.warn(f"Word has more than one vector in embedding file. Set logger level to " | |||||
logger.warning(f"Word has more than one vector in embedding file. Set logger level to " | |||||
f"DEBUG for detail.") | f"DEBUG for detail.") | ||||
logger.debug(f"Word:{word} occurs again in line:{idx}(starts from 0)") | logger.debug(f"Word:{word} occurs again in line:{idx}(starts from 0)") | ||||
matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim)) | matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim)) | ||||
@@ -295,7 +295,7 @@ class StaticEmbedding(TokenEmbedding): | |||||
found_count += 1 | found_count += 1 | ||||
except Exception as e: | except Exception as e: | ||||
if error == 'ignore': | if error == 'ignore': | ||||
logger.warn("Error occurred at the {} line.".format(idx)) | |||||
logger.warning("Error occurred at the {} line.".format(idx)) | |||||
else: | else: | ||||
logger.error("Error occurred at the {} line.".format(idx)) | logger.error("Error occurred at the {} line.".format(idx)) | ||||
raise e | raise e | ||||
@@ -22,5 +22,7 @@ _NEED_IMPORT_FAIRSCALE = not _IS_WINDOWS and _module_available("fairscale") and | |||||
_NEED_IMPORT_TORCH = _module_available("torch") and 'torch' in need_import | _NEED_IMPORT_TORCH = _module_available("torch") and 'torch' in need_import | ||||
_NEED_IMPORT_JITTOR = _module_available("jittor") and 'jittor' in need_import | _NEED_IMPORT_JITTOR = _module_available("jittor") and 'jittor' in need_import | ||||
_NEED_IMPORT_PADDLE = _module_available("paddle") and 'paddle' in need_import | _NEED_IMPORT_PADDLE = _module_available("paddle") and 'paddle' in need_import | ||||
_NEED_IMPORT_DEEPSPEED = _module_available("deepspeed") and 'torch' in need_import | |||||
_NEED_IMPORT_ONEFLOW = _module_available("oneflow") and 'oneflow' in need_import | |||||
_TORCH_GREATER_EQUAL_1_8 = _NEED_IMPORT_TORCH and _compare_version("torch", operator.ge, "1.8.0") | _TORCH_GREATER_EQUAL_1_8 = _NEED_IMPORT_TORCH and _compare_version("torch", operator.ge, "1.8.0") |
@@ -8,7 +8,7 @@ from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VIS | |||||
from fastNLP.envs.utils import _module_available, get_gpu_count | from fastNLP.envs.utils import _module_available, get_gpu_count | ||||
SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor'] | |||||
SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor', 'oneflow'] | |||||
def _set_backend(): | def _set_backend(): | ||||
@@ -145,6 +145,9 @@ def set_env(global_seed=None): | |||||
if backend == 'torch': | if backend == 'torch': | ||||
assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | ||||
if backend == 'oneflow': | |||||
assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | |||||
def dump_fastnlp_backend(default:bool = False, backend=None): | def dump_fastnlp_backend(default:bool = False, backend=None): | ||||
""" | """ | ||||
@@ -50,6 +50,15 @@ def set_env_on_import_jittor(): | |||||
if 'log_silent' not in os.environ: | if 'log_silent' not in os.environ: | ||||
os.environ['log_silent'] = '1' | os.environ['log_silent'] = '1' | ||||
def set_env_on_import_oneflow(): | |||||
if 'GLOG_log_dir' in os.environ: | |||||
os.environ[FASTNLP_GLOBAL_RANK] = os.environ['RANK'] | |||||
if int(os.environ.get(FASTNLP_REMOVE_LOCAL_RANK, 1)): | |||||
remove_local_rank_in_argv() | |||||
if 'GLOG_log_dir' in os.environ and FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||||
os.environ[FASTNLP_BACKEND_LAUNCH] = '1' | |||||
def set_env_on_import(): | def set_env_on_import(): | ||||
""" | """ | ||||
@@ -61,6 +70,7 @@ def set_env_on_import(): | |||||
set_env_on_import_torch() | set_env_on_import_torch() | ||||
set_env_on_import_paddle() | set_env_on_import_paddle() | ||||
set_env_on_import_jittor() | set_env_on_import_jittor() | ||||
set_env_on_import_oneflow() | |||||
# fastNLP 内部使用的一些变量 | # fastNLP 内部使用的一些变量 | ||||
if FASTNLP_LAUNCH_TIME not in os.environ: | if FASTNLP_LAUNCH_TIME not in os.environ: | ||||
@@ -245,8 +245,9 @@ class DataBundle: | |||||
""" | """ | ||||
_progress_desc = progress_desc | _progress_desc = progress_desc | ||||
for name, dataset in self.datasets.items(): | for name, dataset in self.datasets.items(): | ||||
if _progress_desc: | |||||
progress_desc = _progress_desc + f' for `{name}`' | |||||
if len(_progress_desc) == 0: | |||||
_progress_desc = 'Processing' | |||||
progress_desc = _progress_desc + f' for `{name}`' | |||||
if dataset.has_field(field_name=field_name): | if dataset.has_field(field_name=field_name): | ||||
dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, num_proc=num_proc, | dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, num_proc=num_proc, | ||||
progress_desc=progress_desc, progress_bar=progress_bar) | progress_desc=progress_desc, progress_bar=progress_bar) | ||||
@@ -284,8 +285,9 @@ class DataBundle: | |||||
res = {} | res = {} | ||||
_progress_desc = progress_desc | _progress_desc = progress_desc | ||||
for name, dataset in self.datasets.items(): | for name, dataset in self.datasets.items(): | ||||
if _progress_desc: | |||||
progress_desc = _progress_desc + f' for `{name}`' | |||||
if len(_progress_desc) == 0: | |||||
_progress_desc = 'Processing' | |||||
progress_desc = _progress_desc + f' for `{name}`' | |||||
if dataset.has_field(field_name=field_name): | if dataset.has_field(field_name=field_name): | ||||
res[name] = dataset.apply_field_more(func=func, field_name=field_name, num_proc=num_proc, | res[name] = dataset.apply_field_more(func=func, field_name=field_name, num_proc=num_proc, | ||||
modify_fields=modify_fields, | modify_fields=modify_fields, | ||||
@@ -317,8 +319,9 @@ class DataBundle: | |||||
""" | """ | ||||
_progress_desc = progress_desc | _progress_desc = progress_desc | ||||
for name, dataset in self.datasets.items(): | for name, dataset in self.datasets.items(): | ||||
if _progress_desc: | |||||
progress_desc = _progress_desc + f' for `{name}`' | |||||
if len(_progress_desc) == 0: | |||||
_progress_desc = 'Processing' | |||||
progress_desc = _progress_desc + f' for `{name}`' | |||||
dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, progress_bar=progress_bar, | dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, progress_bar=progress_bar, | ||||
progress_desc=progress_desc) | progress_desc=progress_desc) | ||||
return self | return self | ||||
@@ -349,8 +352,9 @@ class DataBundle: | |||||
res = {} | res = {} | ||||
_progress_desc = progress_desc | _progress_desc = progress_desc | ||||
for name, dataset in self.datasets.items(): | for name, dataset in self.datasets.items(): | ||||
if _progress_desc: | |||||
progress_desc = _progress_desc + f' for `{name}`' | |||||
if len(_progress_desc) == 0: | |||||
_progress_desc = 'Processing' | |||||
progress_desc = _progress_desc + f' for `{name}`' | |||||
res[name] = dataset.apply_more(func, modify_fields=modify_fields, num_proc=num_proc, | res[name] = dataset.apply_more(func, modify_fields=modify_fields, num_proc=num_proc, | ||||
progress_bar=progress_bar, progress_desc=progress_desc) | progress_bar=progress_bar, progress_desc=progress_desc) | ||||
return res | return res | ||||
@@ -91,7 +91,7 @@ class EmbedLoader: | |||||
hit_flags[index] = True | hit_flags[index] = True | ||||
except Exception as e: | except Exception as e: | ||||
if error == 'ignore': | if error == 'ignore': | ||||
logger.warn("Error occurred at the {} line.".format(idx)) | |||||
logger.warning("Error occurred at the {} line.".format(idx)) | |||||
else: | else: | ||||
logging.error("Error occurred at the {} line.".format(idx)) | logging.error("Error occurred at the {} line.".format(idx)) | ||||
raise e | raise e | ||||
@@ -156,7 +156,7 @@ class EmbedLoader: | |||||
found_pad = True | found_pad = True | ||||
except Exception as e: | except Exception as e: | ||||
if error == 'ignore': | if error == 'ignore': | ||||
logger.warn("Error occurred at the {} line.".format(idx)) | |||||
logger.warning("Error occurred at the {} line.".format(idx)) | |||||
pass | pass | ||||
else: | else: | ||||
logging.error("Error occurred at the {} line.".format(idx)) | logging.error("Error occurred at the {} line.".format(idx)) | ||||
@@ -345,7 +345,7 @@ class SST2Loader(Loader): | |||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
f.readline() # 跳过header | f.readline() # 跳过header | ||||
if 'test' in os.path.split(path)[1]: | if 'test' in os.path.split(path)[1]: | ||||
logger.warn("SST2's test file has no target.") | |||||
logger.warning("SST2's test file has no target.") | |||||
for line in f: | for line in f: | ||||
line = line.strip() | line = line.strip() | ||||
if line: | if line: | ||||
@@ -55,7 +55,7 @@ class MNLILoader(Loader): | |||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
f.readline() # 跳过header | f.readline() # 跳过header | ||||
if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'): | if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'): | ||||
logger.warn("MNLI's test file has no target.") | |||||
logger.warning("MNLI's test file has no target.") | |||||
for line in f: | for line in f: | ||||
line = line.strip() | line = line.strip() | ||||
if line: | if line: | ||||
@@ -227,7 +227,7 @@ class QNLILoader(JsonLoader): | |||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
f.readline() # 跳过header | f.readline() # 跳过header | ||||
if path.endswith("test.tsv"): | if path.endswith("test.tsv"): | ||||
logger.warn("QNLI's test file has no target.") | |||||
logger.warning("QNLI's test file has no target.") | |||||
for line in f: | for line in f: | ||||
line = line.strip() | line = line.strip() | ||||
if line: | if line: | ||||
@@ -289,7 +289,7 @@ class RTELoader(Loader): | |||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
f.readline() # 跳过header | f.readline() # 跳过header | ||||
if path.endswith("test.tsv"): | if path.endswith("test.tsv"): | ||||
logger.warn("RTE's test file has no target.") | |||||
logger.warning("RTE's test file has no target.") | |||||
for line in f: | for line in f: | ||||
line = line.strip() | line = line.strip() | ||||
if line: | if line: | ||||
@@ -146,7 +146,7 @@ class MatchingBertPipe(Pipe): | |||||
warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \ | warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \ | ||||
f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \ | f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \ | ||||
f"data set but not in train data set!." | f"data set but not in train data set!." | ||||
logger.warn(warn_msg) | |||||
logger.warning(warn_msg) | |||||
print(warn_msg) | print(warn_msg) | ||||
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if | has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if | ||||
@@ -291,7 +291,7 @@ class MatchingPipe(Pipe): | |||||
warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \ | warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \ | ||||
f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \ | f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \ | ||||
f"data set but not in train data set!." | f"data set but not in train data set!." | ||||
logger.warn(warn_msg) | |||||
logger.warning(warn_msg) | |||||
print(warn_msg) | print(warn_msg) | ||||
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if | has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if | ||||