@@ -69,6 +69,7 @@ __all__ = [ | |||||
# metrics | # metrics | ||||
"Metric", | "Metric", | ||||
"Accuracy", | "Accuracy", | ||||
"TransformersAccuracy", | |||||
'SpanFPreRecMetric', | 'SpanFPreRecMetric', | ||||
'ClassifyFPreRecMetric', | 'ClassifyFPreRecMetric', | ||||
@@ -2,7 +2,6 @@ __all__ = [ | |||||
'Callback', | 'Callback', | ||||
'Event', | 'Event', | ||||
'Filter', | 'Filter', | ||||
'CallbackManager', | |||||
'CheckpointCallback', | 'CheckpointCallback', | ||||
'choose_progress_callback', | 'choose_progress_callback', | ||||
'ProgressCallback', | 'ProgressCallback', | ||||
@@ -49,12 +49,17 @@ class Callback: | |||||
def on_after_trainer_initialized(self, trainer, driver): | def on_after_trainer_initialized(self, trainer, driver): | ||||
r""" | r""" | ||||
在 `Trainer` 初始化后会被触发; | 在 `Trainer` 初始化后会被触发; | ||||
:param trainer: ``Trainer`` 实例; | |||||
:param driver: ``Trainer`` 中的 ``driver`` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
def on_sanity_check_begin(self, trainer): | def on_sanity_check_begin(self, trainer): | ||||
r""" | r""" | ||||
在 '预跑'检测 开始前会被触发; | 在 '预跑'检测 开始前会被触发; | ||||
:param trainer: ``Trainer`` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
@@ -62,9 +67,8 @@ class Callback: | |||||
r""" | r""" | ||||
在 '预跑'检测 开始后会被触发; | 在 '预跑'检测 开始后会被触发; | ||||
:param trainer: | |||||
:param sanity_check_res: 预跑的 evaluate 结果 | |||||
:return: | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param sanity_check_res: 预跑得到的评测结果,关于对于 **预跑** 的解释,请见 :meth:`~fastNLP.core.controllers.trainer.Trainer.run`; | |||||
""" | """ | ||||
pass | pass | ||||
@@ -72,8 +76,7 @@ class Callback: | |||||
r""" | r""" | ||||
在训练开始前会被触发; | 在训练开始前会被触发; | ||||
:param trainer: | |||||
:return: | |||||
:param trainer: ``Trainer`` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
@@ -81,8 +84,7 @@ class Callback: | |||||
r""" | r""" | ||||
在训练完成后会被触发; | 在训练完成后会被触发; | ||||
:param trainer: | |||||
:return: | |||||
:param trainer: ``Trainer`` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
@@ -90,8 +92,7 @@ class Callback: | |||||
r""" | r""" | ||||
在训练过程中的每一个 epoch 开始前会被触发; | 在训练过程中的每一个 epoch 开始前会被触发; | ||||
:param trainer: | |||||
:return: | |||||
:param trainer: ``Trainer`` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
@@ -99,8 +100,7 @@ class Callback: | |||||
r""" | r""" | ||||
在训练过程中的每一个 epoch 完成后会被触发;此时 trainer.cur_epoch_idx 已经完成加 1 操作。 | 在训练过程中的每一个 epoch 完成后会被触发;此时 trainer.cur_epoch_idx 已经完成加 1 操作。 | ||||
:param trainer: | |||||
:return: | |||||
:param trainer: ``Trainer`` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
@@ -108,8 +108,7 @@ class Callback: | |||||
r""" | r""" | ||||
在训练过程中准备取出下一个 batch 的数据时触发 | 在训练过程中准备取出下一个 batch 的数据时触发 | ||||
:param trainer: | |||||
:return: | |||||
:param trainer: ``Trainer`` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
@@ -117,178 +116,161 @@ class Callback: | |||||
r""" | r""" | ||||
在训练过程中拿到当前的 batch 数据后会被触发; | 在训练过程中拿到当前的 batch 数据后会被触发; | ||||
:param trainer: | |||||
:return: | |||||
:param trainer: ``Trainer`` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
def on_train_batch_begin(self, trainer, batch, indices): | def on_train_batch_begin(self, trainer, batch, indices): | ||||
r""" | r""" | ||||
在取得数据,执行完 input_mapping (如果 Trainer 传有该参数),并且移动 batch 中的 tensor 到了指定设备。 | |||||
其中 batch 中的数据格式要么是 Dataloader 返回的每个 batch 的格式;要么是 input_mapping 之后的内容。 | |||||
如果 batch 是 dict 类型,直接增删其中的 key 或 修改其中的 value 会影响到输入到 model 的中的 batch 数据。 | |||||
在取得数据,执行完 ``input_mapping`` (如果 ``Trainer`` 传有该参数),并且移动 ``batch`` 中的 ``tensor`` 到了指定设备。 | |||||
其中 ``batch`` 中的数据格式要么是 ``Dataloader`` 返回的每个 ``batch`` 的格式;要么是 ``input_mapping`` 之后的内容。 | |||||
如果 ``batch`` 是 ``dict`` 类型,直接增删其中的 ``key`` 或 修改其中的 ``value`` 会影响到输入到 ``model`` 的中的 ``batch`` 数据。 | |||||
:param trainer: `fastNLP.Trainer` | |||||
:param batch: batch 的数据,已经经过 input_mapping (如果有) 以及 移动到指定设备 。 | |||||
:param list[int] indices: 当前的 batch 是 dataset 中的哪些数据。仅在 DataLoader 支持得到当前 batch index 的时候有值, | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param batch: batch 的数据,已经经过 ``input_mapping`` (如果有) 以及移动到指定设备 。 | |||||
:param list[int] indices: 当前的 ``batch`` 是 ``dataset`` 中的哪些数据。仅在 ``DataLoader`` 支持得到当前 ``batch index`` 的时候有值, | |||||
其它时候为 None 。 | 其它时候为 None 。 | ||||
""" | """ | ||||
pass | pass | ||||
def on_train_batch_end(self, trainer): | def on_train_batch_end(self, trainer): | ||||
""" | |||||
r""" | |||||
完成一个 batch 的训练(forward)、梯度回传(backward)、梯度更新(step)、梯度置零、batch_idx_in_epoch与 | 完成一个 batch 的训练(forward)、梯度回传(backward)、梯度更新(step)、梯度置零、batch_idx_in_epoch与 | ||||
global_forward_batches累计加1操作。其中梯度更新】梯度置零操作会考虑 accumulation_steps ,所以不一定在当前 batch 会 | global_forward_batches累计加1操作。其中梯度更新】梯度置零操作会考虑 accumulation_steps ,所以不一定在当前 batch 会 | ||||
执行。 | 执行。 | ||||
:param trainer: | |||||
:return: | |||||
:param trainer: ``Trainer`` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
def on_exception(self, trainer, exception): | def on_exception(self, trainer, exception): | ||||
""" | |||||
r""" | |||||
在训练过程遇到异常时调用。 | 在训练过程遇到异常时调用。 | ||||
:param trainer: | |||||
:param exception: 遭遇的异常。 | |||||
:return: | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param exception: 遭遇的异常; | |||||
""" | """ | ||||
pass | pass | ||||
def on_save_model(self, trainer): | def on_save_model(self, trainer): | ||||
""" | |||||
当将要保存模型时调用,此刻模型还未保存。 | |||||
r""" | |||||
当调用 Trainer.save_model() 时调用,此刻模型还未保存。 | |||||
:param trainer: | |||||
:return: | |||||
:param trainer: ``Trainer`` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
def on_load_model(self, trainer): | def on_load_model(self, trainer): | ||||
""" | |||||
当将要加载模型时调用,此刻模型还未加载。 | |||||
r""" | |||||
当调用 Trainer.load_model() 加载模型时调用,此刻模型还未加载。 | |||||
:param trainer: | |||||
:return: | |||||
:param trainer: ``Trainer`` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
def on_save_checkpoint(self, trainer) -> Dict: | def on_save_checkpoint(self, trainer) -> Dict: | ||||
""" | |||||
当 Trainer 将要保存 checkpoint 的时候触发,该函数用于保存当前 callback 在恢复需要的相关数据。 | |||||
r""" | |||||
当 Trainer 将要保存 checkpoint 的时候触发 (即调用 Trainer.save_checkpoint() 函数时),该函数用于保存当前 callback 在恢复需要的相关数据。 | |||||
:param trainer: | |||||
:return: | |||||
:param trainer: ``Trainer`` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
def on_load_checkpoint(self, trainer, states: Optional[Dict]): | def on_load_checkpoint(self, trainer, states: Optional[Dict]): | ||||
r""" | r""" | ||||
当 Trainer 要恢复 checkpoint 的时候触发( Trainer 与 Driver 已经加载好自身的状态),参数 states 为 on_save_checkpoint() | |||||
的返回值。 | |||||
当 Trainer 要恢复 checkpoint 的时候触发(即调用 Trainer.load_checkpoint() 函数时 Trainer 与 Driver 已经加载好自身的状态), | |||||
参数 states 为 on_save_checkpoint() 的返回值。 | |||||
:param trainer: | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param states: | :param states: | ||||
:return: | |||||
""" | """ | ||||
pass | pass | ||||
def on_before_backward(self, trainer, outputs): | def on_before_backward(self, trainer, outputs): | ||||
""" | |||||
r""" | |||||
在 backward 前执行。 | 在 backward 前执行。 | ||||
:param trainer: | |||||
:param outputs: model 的返回内容。如果有 output_mapping ,则 outputs 中的内容为已经执行了 output_mapping 后的结果。 | |||||
:return: | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param outputs: ``model`` 的返回内容。如果有 ``output_mapping``,则 ``outputs`` 中的内容为已经执行了 ``output_mapping`` 后的结果。 | |||||
""" | """ | ||||
pass | pass | ||||
def on_after_backward(self, trainer): | def on_after_backward(self, trainer): | ||||
""" | |||||
在 backward 后执行。在多卡场景下,由于 accumulation_steps 的影响,仅在需要真正 update 参数那次梯度回传才会触发梯度同步, | |||||
因此在多卡且使用 accumulation_steps 时,可能存在某些 step 各卡上梯度不一致的问题。 | |||||
r""" | |||||
在 ``backward`` 后执行。在多卡场景下,由于 ``accumulation_steps`` 的影响,仅在需要真正 ``update`` 参数那次梯度回传才会触发梯度同步, | |||||
因此在多卡且使用 ``accumulation_steps`` 时,可能存在某些 ``step`` 各卡上梯度不一致的问题。 | |||||
:param trainer: | |||||
:return: | |||||
:param trainer: ``Trainer`` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
def on_before_optimizers_step(self, trainer, optimizers): | def on_before_optimizers_step(self, trainer, optimizers): | ||||
""" | |||||
r""" | |||||
在进行 optimizer 优化进行前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | 在进行 optimizer 优化进行前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | ||||
:param trainer: | |||||
:param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 | |||||
:return: | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param optimizers: 优化器,内容为在 ``Trainer`` 初始化时传入的值。 | |||||
""" | """ | ||||
pass | pass | ||||
def on_after_optimizers_step(self, trainer, optimizers): | def on_after_optimizers_step(self, trainer, optimizers): | ||||
""" | |||||
r""" | |||||
在进行 optimizer 优化进行后调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | 在进行 optimizer 优化进行后调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | ||||
:param trainer: | |||||
:param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 | |||||
:return: | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param optimizers: 优化器,内容为在 ``Trainer`` 初始化时传入的值。 | |||||
""" | """ | ||||
pass | pass | ||||
def on_before_zero_grad(self, trainer, optimizers): | def on_before_zero_grad(self, trainer, optimizers): | ||||
""" | |||||
r""" | |||||
在进行模型梯度置零前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | 在进行模型梯度置零前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | ||||
:param trainer: | |||||
:param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 | |||||
:return: | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param optimizers: 优化器,内容为在 ``Trainer`` 初始化时传入的值。 | |||||
""" | """ | ||||
pass | pass | ||||
def on_after_zero_grad(self, trainer, optimizers): | def on_after_zero_grad(self, trainer, optimizers): | ||||
""" | |||||
r""" | |||||
在进行模型梯度置零后调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | 在进行模型梯度置零后调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | ||||
:param trainer: | |||||
:param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 | |||||
:return: | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param optimizers: 优化器,内容为在 ``Trainer`` 初始化时传入的值。 | |||||
""" | """ | ||||
pass | pass | ||||
def on_evaluate_begin(self, trainer): | def on_evaluate_begin(self, trainer): | ||||
""" | |||||
r""" | |||||
在将要进行 evaluate 时调用。如果是设置的以 step 数量 或 自定义地 决定 evaluate 的频率,该接口是在 on_train_batch_end 之后 | 在将要进行 evaluate 时调用。如果是设置的以 step 数量 或 自定义地 决定 evaluate 的频率,该接口是在 on_train_batch_end 之后 | ||||
进行调用。如果是以 epoch 数量决定调用,该接口是在 on_train_epoch_end 之后调用。 | 进行调用。如果是以 epoch 数量决定调用,该接口是在 on_train_epoch_end 之后调用。 | ||||
:param trainer: | |||||
:return: | |||||
:param trainer: ``Trainer`` 实例; | |||||
""" | """ | ||||
pass | pass | ||||
def on_evaluate_end(self, trainer, results): | def on_evaluate_end(self, trainer, results): | ||||
""" | |||||
r""" | |||||
结束 evaluate 时调用,并把 evaluate 的结果传入。 | 结束 evaluate 时调用,并把 evaluate 的结果传入。 | ||||
:param trainer: | |||||
:param results: Evaluate 的结果,一般是个 dict 。 | |||||
:return: | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param results: ``Trainer`` 内置的 ``Evaluator`` 评测的结果,通常是个 ``dict``; | |||||
""" | """ | ||||
pass | pass | ||||
@property | @property | ||||
def callback_name(self): | def callback_name(self): | ||||
""" | |||||
callback 的名称,我们会使用该名称从 checkpoint 中读取的相应的 state 并传递给 on_load_checkpoint() 函数。 | |||||
r""" | |||||
``callback`` 的名称,我们会使用该名称从 ``checkpoint`` 中读取的相应的 ``state`` 并传递给 ``on_load_checkpoint()`` 函数。 | |||||
:return: | |||||
:return: 返回用于区分该 ``callback`` 实例的 ``name``; | |||||
""" | """ | ||||
return self.__class__.__name__ | return self.__class__.__name__ | ||||
@property | @property | ||||
def need_reproducible_sampler(self) -> bool: | def need_reproducible_sampler(self) -> bool: | ||||
""" | |||||
r""" | |||||
当前 callback 是否需要能够复现的 sampler 。一般用于 checkpoint 类的 callback 。 | 当前 callback 是否需要能够复现的 sampler 。一般用于 checkpoint 类的 callback 。 | ||||
:return: | |||||
""" | """ | ||||
return False | return False | ||||
@@ -30,20 +30,20 @@ def check_legality(fn): | |||||
class Event: | class Event: | ||||
""" | |||||
与 Trainer.on 函数配合使用,达到控制 callback 函数运行时机的目的。 | |||||
:param value: Trainer 的 callback 时机。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变量分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
""" | |||||
every: Optional[int] | every: Optional[int] | ||||
once: Optional[int] | once: Optional[int] | ||||
def __init__(self, value: str, every: Optional[int] = None, once: Optional[int] = None, | def __init__(self, value: str, every: Optional[int] = None, once: Optional[int] = None, | ||||
filter_fn: Optional[Callable] = None): | filter_fn: Optional[Callable] = None): | ||||
""" | |||||
请勿直接使用本对象,而是通过调用 Event.on_after_trainer_initialized() 等方式调用。 | |||||
:param value: Trainer 的 callback 时机。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变量分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
""" | |||||
self.every = every | self.every = every | ||||
self.once = once | self.once = once | ||||
self.filter_fn = filter_fn | self.filter_fn = filter_fn | ||||
@@ -456,7 +456,7 @@ class Event: | |||||
class Filter: | class Filter: | ||||
def __init__(self, every: Optional[int] = None, once: Optional[bool] = None, filter_fn: Optional[Callable] = None): | def __init__(self, every: Optional[int] = None, once: Optional[bool] = None, filter_fn: Optional[Callable] = None): | ||||
r""" | r""" | ||||
通过该 `Filter` 作为函数修饰器来控制一个函数的实际的运行频率; | |||||
通过该 `Filter` 作为函数修饰器来控制一个函数的实际的运行频率。 | |||||
:param every: 表示一个函数隔多少次运行一次; | :param every: 表示一个函数隔多少次运行一次; | ||||
:param once: 表示一个函数只运行一次; | :param once: 表示一个函数只运行一次; | ||||
@@ -2,10 +2,6 @@ import inspect | |||||
from typing import List, Optional, Dict, Sequence | from typing import List, Optional, Dict, Sequence | ||||
from collections import defaultdict | from collections import defaultdict | ||||
__all__ = [ | |||||
'CallbackManager' | |||||
] | |||||
from .callback_event import Event | from .callback_event import Event | ||||
from .callback import Callback | from .callback import Callback | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -25,7 +21,7 @@ def _transfer(func): | |||||
for callback_fn in manager.callback_fns[func.__name__]: | for callback_fn in manager.callback_fns[func.__name__]: | ||||
try: | try: | ||||
callback_fn(*arg, **kwargs) | callback_fn(*arg, **kwargs) | ||||
except EarlyStopException as e: | |||||
except (EarlyStopException, KeyboardInterrupt) as e: | |||||
raise e | raise e | ||||
except BaseException as e: | except BaseException as e: | ||||
logger.error(f"The following callback_fn raise exception:{_get_fun_msg(callback_fn)}.") | logger.error(f"The following callback_fn raise exception:{_get_fun_msg(callback_fn)}.") | ||||
@@ -33,11 +29,10 @@ def _transfer(func): | |||||
return wrapper | return wrapper | ||||
def prepare_callbacks(callbacks, progress_bar): | |||||
def prepare_callbacks(callbacks, progress_bar: str): | |||||
""" | """ | ||||
:param callbacks: | |||||
:param progress_bar: | |||||
:param callbacks: 对用户传入的类 ``callback`` 进行检查,查看是否是否继承了我们的 ``Callback`` 类; | |||||
:param progress_bar: 选择怎样的 ``progress_bar`` 给 ``Trainer`` 使用; | |||||
:return: | :return: | ||||
""" | """ | ||||
_callbacks = [] | _callbacks = [] | ||||
@@ -85,7 +80,7 @@ class CallbackManager: | |||||
2. 通过 `Trainer` 的参数 `callbacks` 添加的 callback 类; | 2. 通过 `Trainer` 的参数 `callbacks` 添加的 callback 类; | ||||
3. 通过 `Trainer.add_callback_fn` 添加的 callback 函数; | 3. 通过 `Trainer.add_callback_fn` 添加的 callback 函数; | ||||
:param callbacks: 初始化时可以传入的一系列 callback 类,通常为用户在初始化 'Trainer' 时直接传入的 callback 类; | |||||
:param callbacks: 初始化时可以传入的一系列 callback 类,通常为用户在初始化 ``Trainer`` 时直接传入的 callback 类; | |||||
""" | """ | ||||
self._need_reproducible_sampler = False | self._need_reproducible_sampler = False | ||||
@@ -162,7 +157,6 @@ class CallbackManager: | |||||
"filter_states": {"on_train_begin": filter1.state_dict(), ...} | "filter_states": {"on_train_begin": filter1.state_dict(), ...} | ||||
} | } | ||||
} | } | ||||
""" | """ | ||||
states = {} | states = {} | ||||
@@ -9,58 +9,59 @@ import sys | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from .topk_saver import TopkSaver | from .topk_saver import TopkSaver | ||||
from .callback import Callback | from .callback import Callback | ||||
from ..utils.exceptions import EarlyStopException | |||||
class CheckpointCallback(Callback): | class CheckpointCallback(Callback): | ||||
""" | |||||
保存 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下:: | |||||
- folder/ | |||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||||
- {save_object}-epoch_{epoch_idx}/ # 满足 every_n_epochs 条件保存的模型 | |||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 every_n_batches 保存的模型 | |||||
- {save_object}-last/ # 最后一个 epoch 的保存 | |||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 | |||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||||
model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。若 model_save_fn 不为 None, | |||||
则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。默认情况下,本 checkpoint 只保存了 model | |||||
的状态;如还需保存 Trainer 的状态以断点重训的话,请使用 ``save_object='trainer'`` 。 | |||||
:param monitor: 监控的 metric 值。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||||
:param every_n_epochs: 多少个 epoch 保存一次。 | |||||
:param every_n_batches: 多少个 batch 保存一次。 | |||||
:param last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||||
:param topk: 保存 monitor 结果 topK 个。 | |||||
:param on_exceptions: 在出异常信息时,是否保存。传入需要捕获的异常的类。默认将捕获 EarlyStopException 。 | |||||
:param larger_better: monitor 的值是否时越大越好。 | |||||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | |||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||||
保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||||
点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | |||||
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 | |||||
fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。 | |||||
:param kwargs: | |||||
""" | |||||
def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None, | def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None, | ||||
every_n_batches: Optional[int] = None, last: bool = False, | |||||
on_exceptions: Optional[Union[BaseException, Sequence[BaseException]]] = None, topk: int = 0, | |||||
every_n_batches: Optional[int] = None, last: bool = False, topk: int = 0, | |||||
on_exceptions: Optional[Union[BaseException, Sequence[BaseException]]] = [EarlyStopException], | |||||
monitor: Optional[Union[str, Callable]] = None, larger_better: bool = True, | monitor: Optional[Union[str, Callable]] = None, larger_better: bool = True, | ||||
only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model', | only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model', | ||||
save_evaluate_results=True, **kwargs): | save_evaluate_results=True, **kwargs): | ||||
""" | |||||
保存 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下:: | |||||
- folder/ | |||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||||
- {save_object}-epoch_{epoch_idx}/ # 满足 every_n_epochs 条件保存的模型 | |||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 every_n_batches 保存的模型 | |||||
- {save_object}-last/ # 最后一个 epoch 的保存 | |||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 | |||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||||
model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。若 model_save_fn 不为 None, | |||||
则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。默认情况下,本 checkpoint 只保存了 model | |||||
的状态;如还需保存 Trainer 的状态以断点重训的话,请使用 ``save_object='trainer'`` 。 | |||||
:param monitor: 监控的 metric 值。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||||
:param every_n_epochs: 多少个 epoch 保存一次。 | |||||
:param every_n_batches: 多少个 batch 保存一次。 | |||||
:param last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||||
:param topk: 保存 monitor 结果 topK 个。 | |||||
:param on_exceptions: 在出异常信息时,是否保存。传入需要捕获的异常的类。 | |||||
:param larger_better: monitor 的值是否时越大越好。 | |||||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | |||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||||
保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load` 加载该断 | |||||
点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | |||||
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 | |||||
fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。 | |||||
:param kwargs: | |||||
""" | |||||
super().__init__() | super().__init__() | ||||
if every_n_epochs is not None: | if every_n_epochs is not None: | ||||
if not isinstance(every_n_epochs, int) or every_n_epochs < 1: | if not isinstance(every_n_epochs, int) or every_n_epochs < 1: | ||||
@@ -132,10 +133,6 @@ class CheckpointCallback(Callback): | |||||
self.topk_saver.save(trainer, folder_name=folder_name) | self.topk_saver.save(trainer, folder_name=folder_name) | ||||
def on_save_checkpoint(self, trainer) -> Dict: | def on_save_checkpoint(self, trainer) -> Dict: | ||||
""" | |||||
保存状态,以便之后可以继续使用 | |||||
""" | |||||
states = {} | states = {} | ||||
states['topk_saver'] = self.topk_saver.state_dict() | states['topk_saver'] = self.topk_saver.state_dict() | ||||
return states | return states | ||||
@@ -9,22 +9,23 @@ from fastNLP.core.utils.exceptions import EarlyStopException | |||||
class EarlyStopCallback(HasMonitorCallback): | class EarlyStopCallback(HasMonitorCallback): | ||||
def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10): | |||||
""" | |||||
""" | |||||
用于 early stop 的 callback 。当监控的结果连续多少次没有变好边 raise 一个 EarlyStopException 。 | |||||
:param monitor: 监控的 metric 值。 | |||||
:param monitor: 监控的 metric 值。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: monitor 的值是否是越大越好。 | |||||
:param patience: 多少次 evaluate 不没有提升就停止。 | |||||
""" | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: monitor 的值是否是越大越好。 | |||||
:param patience: 多少次 evaluate 不没有提升就停止。 | |||||
""" | |||||
def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10): | |||||
super(EarlyStopCallback, self).__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True) | super(EarlyStopCallback, self).__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True) | ||||
self.wait = 0 | self.wait = 0 | ||||
self.patience = patience | self.patience = patience | ||||
@@ -42,7 +43,7 @@ class EarlyStopCallback(HasMonitorCallback): | |||||
# 当是 step evaluate 的时候,下一步执行的就是这个, 所以在这里检查。 | # 当是 step evaluate 的时候,下一步执行的就是这个, 所以在这里检查。 | ||||
if self.wait >= self.patience: | if self.wait >= self.patience: | ||||
raise EarlyStopException(f"After {self.wait} validations, no improvement for " | raise EarlyStopException(f"After {self.wait} validations, no improvement for " | ||||
f"metric `{self._real_monitor}`") | |||||
f"metric `{self._real_monitor}`(best value: {self.monitor_value})") | |||||
def on_train_epoch_begin(self, trainer): | def on_train_epoch_begin(self, trainer): | ||||
# 当是 epoch evaluate 的时候,下一步执行的就是这个, 所以在这里检查。 | # 当是 epoch evaluate 的时候,下一步执行的就是这个, 所以在这里检查。 | ||||
@@ -1,9 +1,12 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'FitlogCallback' | 'FitlogCallback' | ||||
] | ] | ||||
import os | |||||
from .has_monitor_callback import HasMonitorCallback | from .has_monitor_callback import HasMonitorCallback | ||||
from ...envs import _module_available | from ...envs import _module_available | ||||
from ...envs import get_global_rank | from ...envs import get_global_rank | ||||
from ..log import logger | |||||
if _module_available('fitlog'): | if _module_available('fitlog'): | ||||
import fitlog | import fitlog | ||||
@@ -11,7 +14,9 @@ if _module_available('fitlog'): | |||||
class FitlogCallback(HasMonitorCallback): | class FitlogCallback(HasMonitorCallback): | ||||
""" | """ | ||||
自动记录 ``evaluation`` 结果到 ``fitlog`` 中。会自动记录每一次 ``evaluate`` 后的结果;同时会根据 | 自动记录 ``evaluation`` 结果到 ``fitlog`` 中。会自动记录每一次 ``evaluate`` 后的结果;同时会根据 | ||||
``monitor`` 记录最好的结果。另外,会自动将非 ``rank 0`` 上的 ``fitlog`` 设置为 ``debug`` 状态。 | |||||
``monitor`` 记录最好的结果。另外,会自动将非 ``rank 0`` 上的 ``fitlog`` 设置为 ``debug`` 状态。同时还会在 ``fitlog`` 的 | |||||
``other`` 列中记录一个 ``launch_time`` ,可以通过这个数值找到当前这个脚本的在 save_folder (如果有使用其它需要保存模型的 | |||||
``Callback`` ,例如 :class:`~fastNLP.CheckpointCallback` )下的文件夹名称。 | |||||
:param monitor: 监控的 metric 值。 | :param monitor: 监控的 metric 值。 | ||||
@@ -38,6 +43,14 @@ class FitlogCallback(HasMonitorCallback): | |||||
def on_after_trainer_initialized(self, trainer, driver): | def on_after_trainer_initialized(self, trainer, driver): | ||||
if get_global_rank() != 0: # 如果不是 global rank 为 0 ,需要关闭 fitlog | if get_global_rank() != 0: # 如果不是 global rank 为 0 ,需要关闭 fitlog | ||||
fitlog.debug() | fitlog.debug() | ||||
super().on_after_trainer_initialized(trainer, driver) | |||||
fitlog.add_other('launch_time', os.environ['FASTNLP_LAUNCH_TIME']) | |||||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||||
super(FitlogCallback, self).on_sanity_check_end(trainer, sanity_check_res) | |||||
if self.monitor is None: | |||||
logger.rank_zero_warning(f"No monitor set for {self.__class__.__name__}. Therefore, no best metric will " | |||||
f"be logged.") | |||||
def on_evaluate_end(self, trainer, results): | def on_evaluate_end(self, trainer, results): | ||||
results = self.itemize_results(results) | results = self.itemize_results(results) | ||||
@@ -16,11 +16,6 @@ from fastNLP.core.utils.utils import _check_valid_parameters_number | |||||
class CanItemDataType(ABC): | class CanItemDataType(ABC): | ||||
""" | |||||
检测可以进行传输的对象。 | |||||
""" | |||||
@classmethod | @classmethod | ||||
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: | def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: | ||||
if cls is CanItemDataType: | if cls is CanItemDataType: | ||||
@@ -30,15 +25,22 @@ class CanItemDataType(ABC): | |||||
class ResultsMonitor: | class ResultsMonitor: | ||||
""" | |||||
可用于监控某个数值,并通过 is_better_results() 等接口实现检测结果是否变得更好了。 | |||||
:param monitor: 监控的 metric 值。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: monitor 是否时越大越好 | |||||
""" | |||||
def __init__(self, monitor:Union[Callback, str], larger_better:bool=True): | def __init__(self, monitor:Union[Callback, str], larger_better:bool=True): | ||||
""" | |||||
可用于监控某个数值,并通过 is_better_results() 等接口实现检测结果是否变得更好了。 | |||||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | |||||
:param larger_better: monitor 是否时越大越好 | |||||
""" | |||||
self.set_monitor(monitor, larger_better) | self.set_monitor(monitor, larger_better) | ||||
def set_monitor(self, monitor, larger_better): | def set_monitor(self, monitor, larger_better): | ||||
@@ -66,9 +68,9 @@ class ResultsMonitor: | |||||
def get_monitor_value(self, results:Dict)->Union[float, None]: | def get_monitor_value(self, results:Dict)->Union[float, None]: | ||||
""" | """ | ||||
获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用匹配的方式寻找,并把匹配到的设置到 self._real_monitor 属性上。 | |||||
获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用 最长公共字符串算法 匹配的方式寻找。 | |||||
:param results: | |||||
:param results: 评测结果。 | |||||
:return: 如果为 None ,表明此次没有找到合适的monitor | :return: 如果为 None ,表明此次没有找到合适的monitor | ||||
""" | """ | ||||
if len(results) == 0 or self.monitor is None: | if len(results) == 0 or self.monitor is None: | ||||
@@ -113,7 +115,7 @@ class ResultsMonitor: | |||||
""" | """ | ||||
检测给定的 results 是否比上一次更好,如果本次 results 中没有找到相关的monitor 返回 False。 | 检测给定的 results 是否比上一次更好,如果本次 results 中没有找到相关的monitor 返回 False。 | ||||
:param results: on_valid_ends() 接口中传入的 evaluation 结果。 | |||||
:param results: evaluation 结果。 | |||||
:param keep_if_better: 当返回为 True 时,是否保存到 self.monitor_value 中。 | :param keep_if_better: 当返回为 True 时,是否保存到 self.monitor_value 中。 | ||||
:return: | :return: | ||||
""" | """ | ||||
@@ -166,24 +168,24 @@ class ResultsMonitor: | |||||
class HasMonitorCallback(ResultsMonitor, Callback): | class HasMonitorCallback(ResultsMonitor, Callback): | ||||
""" | |||||
该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了 | |||||
(1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。 | |||||
:param monitor: 监控的 metric 值。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: monitor 是否时越大越好 | |||||
:param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 True ,且没检测到设置 monitor 会报错。 | |||||
""" | |||||
def __init__(self, monitor, larger_better, must_have_monitor=False): | def __init__(self, monitor, larger_better, must_have_monitor=False): | ||||
""" | |||||
该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了 | |||||
(1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。 | |||||
:param monitor: 监控的 metric 值。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: monitor 是否时越大越好 | |||||
:param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 True ,且没检测到设置 monitor 会报错。 | |||||
""" | |||||
super().__init__(monitor, larger_better) | super().__init__(monitor, larger_better) | ||||
self.must_have_monitor = must_have_monitor | self.must_have_monitor = must_have_monitor | ||||
@@ -212,16 +214,23 @@ class HasMonitorCallback(ResultsMonitor, Callback): | |||||
class ExecuteOnceBetterMonitor(HasMonitorCallback): | class ExecuteOnceBetterMonitor(HasMonitorCallback): | ||||
""" | |||||
当监控的 monitor 结果更好的时候,调用 execute_fn 函数。 | |||||
:param monitor: 监控的 metric 值。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: monitor 是否时越大越好 | |||||
:param execute_fn: 一个可执行的函数,不接受任何参数,不反回值。在 monitor 取得更好结果的时候会调用。 | |||||
""" | |||||
def __init__(self, monitor, larger_better, execute_fn): | def __init__(self, monitor, larger_better, execute_fn): | ||||
""" | |||||
当监控的 monitor 结果更好的时候,调用 execute_fn 函数。 | |||||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | |||||
:param larger_better: monitor 是否时越大越好 | |||||
:param execute_fn: 一个可执行的函数,不接受任何参数,不反回值。在 monitor 取得更好结果的时候会调用。 | |||||
""" | |||||
super().__init__(monitor, larger_better, must_have_monitor=True) | super().__init__(monitor, larger_better, must_have_monitor=True) | ||||
_check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn') | _check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn') | ||||
self.execute_fn = execute_fn | self.execute_fn = execute_fn | ||||
@@ -14,34 +14,34 @@ from fastNLP.envs import all_rank_call_context | |||||
class LoadBestModelCallback(HasMonitorCallback): | class LoadBestModelCallback(HasMonitorCallback): | ||||
""" | |||||
保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型,默认会在加载之后删除权重文件。仅在训练正常结束的时候才能加载 | |||||
最好的模型。 | |||||
:param monitor: 监控的 metric 值。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: 该 metric 值是否是越大越好。 | |||||
:param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | |||||
不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | |||||
:param only_state_dict: 是否只保存模型的参数。当 model_save_fn 不为空时,该值无效。 | |||||
:param model_save_fn: 保存 model 的函数,与 model_load_fn 必须同时不为空。本函数的输入为一个已经创建好的文件夹,没有输出, | |||||
请在函数内完成对模型的保存。 | |||||
:param model_load_fn: 加载 model 的函数,与 model_save_fn 必须同时不为空。本函数的输入为一个已经创建好的文件夹,没有输出, | |||||
请在函数内完成对模型的加载。 | |||||
:param delete_after_train: 在训练结束后是否删掉模型。 | |||||
""" | |||||
def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool = True, only_state_dict:bool = True, | def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool = True, only_state_dict:bool = True, | ||||
save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None, | save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None, | ||||
model_load_fn:Optional[Callable] = None, | model_load_fn:Optional[Callable] = None, | ||||
delete_after_train:bool = True): | delete_after_train:bool = True): | ||||
""" | |||||
保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型,默认会在加载之后删除权重文件。仅在训练正常结束的时候才能加载 | |||||
最好的模型。 | |||||
:param monitor: 监控的 metric 值。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: 该 metric 值是否是越大越好。 | |||||
:param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | |||||
不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | |||||
:param only_state_dict: 是否只保存模型的参数。当 model_save_fn 不为空时,该值无效。 | |||||
:param model_save_fn: 保存 model 的函数,与 model_load_fn 必须同时不为空。本函数的输入为一个已经创建好的文件夹,没有输出, | |||||
请在函数内完成对模型的保存。 | |||||
:param model_load_fn: 加载 model 的函数,与 model_save_fn 必须同时不为空。本函数的输入为一个已经创建好的文件夹,没有输出, | |||||
请在函数内完成对模型的加载。 | |||||
:param delete_after_train: 在训练结束后是否删掉模型。 | |||||
""" | |||||
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True) | super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True) | ||||
if model_load_fn is not None: | if model_load_fn is not None: | ||||
assert callable(model_load_fn), "`model_load_fn` must be a callable object." | assert callable(model_load_fn), "`model_load_fn` must be a callable object." | ||||
@@ -6,14 +6,14 @@ __all__ = [ | |||||
class LRSchedCallback(Callback): | class LRSchedCallback(Callback): | ||||
def __init__(self, scheduler, step_on:str='batch'): | |||||
""" | |||||
根据 step_on 参数在合适的时机调用 scheduler 的 step 函数。 | |||||
""" | |||||
根据 step_on 参数在合适的时机调用 scheduler 的 step 函数。 | |||||
:param scheduler: 实现了 step() 函数的对象 | |||||
:param step_on: 可选 ['batch', 'epoch'] 表示在何时调用 scheduler 的 step 函数。如果为 batch 的话在每次更新参数 | |||||
之前调用;如果为 epoch 则是在一个 epoch 运行结束后调用。 | |||||
""" | |||||
:param scheduler: 实现了 step() 函数的对象 | |||||
:param step_on: 可选 ['batch', 'epoch'] 表示在何时调用 scheduler 的 step 函数。如果为 batch 的话在每次更新参数 | |||||
之前调用;如果为 epoch 则是在一个 epoch 运行结束后调用。 | |||||
""" | |||||
def __init__(self, scheduler, step_on:str='batch'): | |||||
assert hasattr(scheduler, 'step') and callable(scheduler.step), "The scheduler object should have a " \ | assert hasattr(scheduler, 'step') and callable(scheduler.step), "The scheduler object should have a " \ | ||||
"step function." | "step function." | ||||
self.scheduler = scheduler | self.scheduler = scheduler | ||||
@@ -11,6 +11,67 @@ from .topk_saver import TopkSaver | |||||
class MoreEvaluateCallback(HasMonitorCallback): | class MoreEvaluateCallback(HasMonitorCallback): | ||||
""" | |||||
当评测时需要调用不同的 evaluate_fn (例如在大部分生成任务中,一般使用训练 loss 作为训练过程中的 evaluate ;但同时在训练到 | |||||
一定 epoch 数量之后,会让 model 生成的完整的数据评测 bleu 等。此刻就可能需要两种不同的 evaluate_fn ),只使用 Trainer | |||||
无法满足需求,可以通过调用本 callback 进行。如果需要根据本 callback 中的评测结果进行模型保存,请传入 topk 以及 | |||||
topk_monitor 等相关参数。可以通过 evaluate_every 或 watch_monitor 控制触发进行 evaluate 的条件。 | |||||
如果设置了 evaluate 结果更好就保存的话,将按如下文件结构进行保存:: | |||||
- folder/ | |||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||||
:param dataloaders: 需要评估的数据 | |||||
:param metrics: 使用的 metrics 。 | |||||
:param evaluate_every: 用来控制 ``Trainer`` 内部的 ``Evaluator`` 验证的频率,其可以为负数、正数或者函数: | |||||
1. 为负数时表示每隔几个 ``epoch`` evaluate 一次; | |||||
2. 为正数则表示每隔几个 ``batch`` evaluate 一次; | |||||
3. 为函数时表示用户自己传入的用于控制 evaluate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 | |||||
返回一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 ``batch`` 结束后调用该函数判断是否需要 evaluate; | |||||
.. note:: | |||||
如果参数 ``evaluate_every`` 为函数,其应当类似: | |||||
>>> def my_evaluate_every(trainer) -> bool: | |||||
... if (trainer.global_forward_batches+1) % 1000 == 0: | |||||
... return True | |||||
... else: | |||||
... return False | |||||
该函数表示当每经过 1000 个 batch,``Trainer`` 中内置的 ``Evaluator`` 就会验证一次; | |||||
另一个需要注意的事情在于该函数会在每一次 batch 的结尾进行调用,当该函数返回 ``True`` 时,``Evaluator`` 才会进行验证; | |||||
:param watch_monitor: 这个值用来表示监控的 Trainer 中的 evaluate 结果的,当该值不为 None ,evaluate_every 失效。本参数的 | |||||
意义是,当检测到 Trainer 中 evaluate results 的 {watch_monitor} 的结果更好时,则进行一次 evaluate 。该参数有两种 | |||||
取值: (1) str 类型,监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最 | |||||
匹配的那个作为 monitor ; (2) 也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor | |||||
的结果,如果当前结果中没有相关的monitor 值请返回 None 。 | |||||
:param watch_monitor_larger_better: watch_monitor 是否越大越好。 | |||||
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 | |||||
`model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有 | |||||
找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。 | |||||
:param num_eval_sanity_batch: 在初始化 Evaluator 后运行多少个 sanity check 的 batch ,检测一下。 | |||||
:param topk: 如果需要根据当前 callback 中的 evaluate 结果保存模型或 Trainer ,可以通过设置 tokp 实现。(1)为 -1 表示每次 | |||||
evaluate 后都保存;(2)为 0 (默认),表示不保存;(3)为整数,表示保存性能最 topk 个。 | |||||
:param topk_monitor: 如果需要根据当前 callback 中的 evaluate 结果保存。这个参数是指在当前 callback 中的 evaluate 结果寻找 | |||||
:param topk_larger_better: topk_monitor 的值是否时越大越好。 | |||||
:param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | |||||
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||||
保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||||
点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | |||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 | |||||
``fastnlp_evaluate_results.json`` 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。 | |||||
:param save_kwargs: dict。更多的保存相关的参数。 | |||||
:param kwargs: 其它与 Evaluator 相关的初始化参数,如果不传入,将从 Trainer 中获取。 | |||||
""" | |||||
def __init__(self, dataloaders, metrics:Dict, evaluate_every:Optional[Union[int, Callable]]=-1, | def __init__(self, dataloaders, metrics:Dict, evaluate_every:Optional[Union[int, Callable]]=-1, | ||||
watch_monitor:Union[str, Callable]=None, watch_monitor_larger_better:bool=True, | watch_monitor:Union[str, Callable]=None, watch_monitor_larger_better:bool=True, | ||||
evaluate_fn=None, num_eval_sanity_batch=2, | evaluate_fn=None, num_eval_sanity_batch=2, | ||||
@@ -18,48 +79,6 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||||
folder=None, only_state_dict=True, save_object='model', model_save_fn=None, | folder=None, only_state_dict=True, save_object='model', model_save_fn=None, | ||||
save_evaluate_results=True, save_kwargs=None, | save_evaluate_results=True, save_kwargs=None, | ||||
**kwargs): | **kwargs): | ||||
""" | |||||
当评测时需要调用不同的 evaluate_fn (例如在大部分生成任务中,一般使用训练 loss 作为训练过程中的 evaluate ;但同时在训练到 | |||||
一定 epoch 数量之后,会让 model 生成的完整的数据评测 bleu 等。此刻就可能需要两种不同的 evaluate_fn ),只使用 Trainer | |||||
无法满足需求,可以通过调用本 callback 进行。如果需要根据本 callback 中的评测结果进行模型保存,请传入 topk 以及 | |||||
topk_monitor 等相关参数。可以通过 evaluate_every 或 watch_monitor 控制触发进行 evaluate 的条件。 | |||||
如果设置了 evaluate 结果更好就保存的话,将按如下文件结构进行保存:: | |||||
- folder/ | |||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||||
:param dataloaders: 需要评估的数据 | |||||
:param metrics: 使用的 metrics 。 | |||||
:param evaluate_every: 可以为负数、正数和函数;(1) 为负整数时表示每隔几个 epoch evaluate 一次;(2) 为正整数则表示每隔几个 batch | |||||
evaluate 一次;(3) 为函数时表示用户自己传入的用于控制 evaluate 的频率的函数,该函数的应该接受 trainer 对象作为参数,并返回 | |||||
一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 batch 结束后调用该函数判断是否需要 evaluate 。 | |||||
:param watch_monitor: 这个值用来表示监控的 Trainer 中的 evaluate 结果的,当该值不为 None ,evaluate_every 失效。本参数的 | |||||
意义是,当检测到 Trainer 中 evaluate results 的 {watch_monitor} 的结果更好时,则进行一次 evaluate 。该参数有两种 | |||||
取值: (1) str 类型,监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最 | |||||
匹配的那个作为 monitor ; (2) 也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor | |||||
的结果,如果当前结果中没有相关的monitor 值请返回 None 。 | |||||
:param watch_monitor_larger_better: watch_monitor 是否越大越好。 | |||||
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 | |||||
`model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有 | |||||
找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。 | |||||
:param num_eval_sanity_batch: 在初始化 Evaluator 后运行多少个 sanity check 的 batch ,检测一下。 | |||||
:param topk: 如果需要根据当前 callback 中的 evaluate 结果保存模型或 Trainer ,可以通过设置 tokp 实现。(1)为 -1 表示每次 | |||||
evaluate 后都保存;(2)为 0 (默认),表示不保存;(3)为整数,表示保存性能最 topk 个。 | |||||
:param topk_monitor: 如果需要根据当前 callback 中的 evaluate 结果保存。这个参数是指在当前 callback 中的 evaluate 结果寻找 | |||||
:param topk_larger_better: topk_monitor 的值是否时越大越好。 | |||||
:param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | |||||
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 trainer+model 还是 只是model 。 | |||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 | |||||
fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。 | |||||
:param save_kwargs: dict。更多的保存相关的参数。 | |||||
:param kwargs: 其它与 Evaluator 相关的初始化参数,如果不传入,将从 Trainer 中获取。请特别留意 evaluate_fn 的设置。 | |||||
""" | |||||
super(MoreEvaluateCallback, self).__init__(watch_monitor, watch_monitor_larger_better, | super(MoreEvaluateCallback, self).__init__(watch_monitor, watch_monitor_larger_better, | ||||
must_have_monitor=False) | must_have_monitor=False) | ||||
@@ -39,25 +39,27 @@ def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> Prog | |||||
class RichCallback(ProgressCallback): | class RichCallback(ProgressCallback): | ||||
""" | |||||
在训练过程中打印 rich progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的 | |||||
参数,请通过实例化本 Callback 并传入到 Trainer 中实现。 | |||||
:param print_every: 多少个 batch 更新一次显示。 | |||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||||
:param monitor: 监控的 metric 值。当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: 是否是 monitor 的结果越大越好。 | |||||
:param format_json: 是否格式化 json 再打印 | |||||
""" | |||||
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | ||||
format_json=True): | format_json=True): | ||||
""" | |||||
:param print_every: 多少个 batch 更新一次显示。 | |||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||||
:param monitor: 监控的 metric 值。当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: 是否是 monitor 的结果越大越好。 | |||||
:param format_json: 是否格式化 json 再打印 | |||||
""" | |||||
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) | super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) | ||||
self.print_every = print_every | self.print_every = print_every | ||||
self.progress_bar = f_rich_progress | self.progress_bar = f_rich_progress | ||||
@@ -16,22 +16,24 @@ from .has_monitor_callback import ResultsMonitor | |||||
class Saver: | class Saver: | ||||
""" | |||||
执行保存的对象。保存的文件组织结构为:: | |||||
- folder # 当前初始化的参数 | |||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||||
- folder_name # 由 save() 调用时传入。 | |||||
:param folder: 保存在哪个文件夹下,默认为当前 folder 下。 | |||||
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||||
保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||||
点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | |||||
:param only_state_dict: 保存时是否仅保存权重,在 model_save_fn 不为 None 时无意义。 | |||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
:param kwargs: 更多需要传递给 Trainer.save_checkpoint() 或者 Trainer.save_model() 接口的参数。 | |||||
""" | |||||
def __init__(self, folder:str=None, save_object:str='model', only_state_dict:bool=True, | def __init__(self, folder:str=None, save_object:str='model', only_state_dict:bool=True, | ||||
model_save_fn:Callable=None, **kwargs): | model_save_fn:Callable=None, **kwargs): | ||||
""" | |||||
执行保存的对象。保存的文件组织结构为:: | |||||
- folder # 当前初始化的参数 | |||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||||
- folder_name # 由 save() 调用时传入。 | |||||
:param folder: 保存在哪个文件夹下,默认为当前 folder 下。 | |||||
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 trainer+model 还是 只是model 。 | |||||
:param only_state_dict: 保存时是否仅保存权重,在 model_save_fn 不为 None 时无意义。 | |||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
:param kwargs: 更多需要传递给 Trainer.save() 或者 Trainer.save_model() 接口的参数。 | |||||
""" | |||||
if folder is None: | if folder is None: | ||||
folder = Path.cwd().absolute() | folder = Path.cwd().absolute() | ||||
logger.info(f"Parameter `folder` is None, and we will use {folder} to save and load your model.") | logger.info(f"Parameter `folder` is None, and we will use {folder} to save and load your model.") | ||||
@@ -46,7 +48,7 @@ class Saver: | |||||
self.model_save_fn = model_save_fn | self.model_save_fn = model_save_fn | ||||
self.kwargs = kwargs | self.kwargs = kwargs | ||||
self.save_object = save_object | self.save_object = save_object | ||||
self.save_fn_name = 'save' if save_object == 'trainer' else 'save_model' | |||||
self.save_fn_name = 'save_checkpoint' if save_object == 'trainer' else 'save_model' | |||||
self.timestamp_path = self.folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | self.timestamp_path = self.folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | ||||
@@ -79,8 +81,8 @@ class Saver: | |||||
""" | """ | ||||
以 json 格式保存 results 到 path 中 | 以 json 格式保存 results 到 path 中 | ||||
:param results: | |||||
:param path: | |||||
:param results: 一般是评测后的结果。 | |||||
:param path: 保存的文件名 | |||||
:return: | :return: | ||||
""" | """ | ||||
with open(path, 'w', encoding='utf8') as f: | with open(path, 'w', encoding='utf8') as f: | ||||
@@ -117,12 +119,12 @@ class Saver: | |||||
class TopkQueue: | class TopkQueue: | ||||
def __init__(self, topk): | |||||
""" | |||||
用于维护处于 topk 的 key, value 对。 | |||||
""" | |||||
用于维护处于 topk 的 key, value 对。 | |||||
:param int topk: 整数,-1 表示所有数据都是 topk 的; 如果是 0, 表示没有任何数据是满足 topk 的。 | |||||
""" | |||||
:param int topk: 整数,-1 表示所有数据都是 topk 的; 如果是 0, 表示没有任何数据是满足 topk 的。 | |||||
""" | |||||
def __init__(self, topk): | |||||
assert isinstance(topk, int) | assert isinstance(topk, int) | ||||
self.topk = topk | self.topk = topk | ||||
self.topk_dict = {} # 其中 key 为保存的内容, value 是对应的性能。 | self.topk_dict = {} # 其中 key 为保存的内容, value 是对应的性能。 | ||||
@@ -170,31 +172,39 @@ class TopkQueue: | |||||
class TopkSaver(ResultsMonitor, Saver): | class TopkSaver(ResultsMonitor, Saver): | ||||
""" | |||||
用来识别 topk 模型并保存,也可以仅当一个保存 Saver 使用。保存路径为:: | |||||
- folder/ | |||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||||
:param topk: 保存 topk 多少的模型,-1 为保存所有模型;0 为都不保存;大于 0 的数为保存 topk 个。 | |||||
:param monitor: 监控的 metric 值。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: 该 monitor 是否越大越好。 | |||||
:param folder: 保存在哪个文件夹下,默认为当前 folder 下。 | |||||
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||||
保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||||
点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | |||||
:param only_state_dict: 保存时是否仅保存权重,在 model_save_fn 不为 None 时无意义。 | |||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 | |||||
``fastnlp_evaluate_results.json`` 文件,记录当前的 metric results 。仅在设置了 topk 的场景下有用,默认为 True 。 | |||||
:param kwargs: 更多需要传递给 Trainer.save_checkpoint() 或者 Trainer.save_model() 接口的参数。 | |||||
""" | |||||
def __init__(self, topk:int=0, monitor:str=None, larger_better:bool=True, folder:str=None, save_object:str='model', | def __init__(self, topk:int=0, monitor:str=None, larger_better:bool=True, folder:str=None, save_object:str='model', | ||||
only_state_dict:bool=True, model_save_fn:Callable=None, save_evaluate_results:bool=True, | only_state_dict:bool=True, model_save_fn:Callable=None, save_evaluate_results:bool=True, | ||||
**kwargs): | **kwargs): | ||||
""" | |||||
用来识别 topk 模型并保存,也可以仅当一个保存 Saver 使用。保存路径为:: | |||||
- folder/ | |||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||||
:param topk: 保存 topk 多少的模型,-1 为保存所有模型;0 为都不保存;大于 0 的数为保存 topk 个。 | |||||
:param monitor: 监控哪个指标判断是否是 topk 的。监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 | |||||
最长公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数, | |||||
接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请 | |||||
返回 None 。 | |||||
:param larger_better: 该 monitor 是否越大越好。 | |||||
:param folder: 保存在哪个文件夹下,默认为当前 folder 下。 | |||||
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 trainer+model 还是 只是model 。 | |||||
:param only_state_dict: 保存时是否仅保存权重,在 model_save_fn 不为 None 时无意义。 | |||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 | |||||
fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。 | |||||
:param kwargs: 更多需要传递给 Trainer.save() 或者 Trainer.save_model() 接口的参数。 | |||||
""" | |||||
ResultsMonitor.__init__(self, monitor, larger_better) | ResultsMonitor.__init__(self, monitor, larger_better) | ||||
Saver.__init__(self, folder, save_object, only_state_dict, model_save_fn, **kwargs) | Saver.__init__(self, folder, save_object, only_state_dict, model_save_fn, **kwargs) | ||||
@@ -1,25 +1,26 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'TorchGradClipCallback' | 'TorchGradClipCallback' | ||||
] | ] | ||||
from typing import Union, List | |||||
from ..callback import Callback | from ..callback import Callback | ||||
class TorchGradClipCallback(Callback): | class TorchGradClipCallback(Callback): | ||||
def __init__(self, clip_value=1, clip_type='norm', parameters=None): | |||||
r""" | |||||
在每次 optimizer update 之前将 parameter 进行 clip | |||||
r""" | |||||
在每次 optimizer update 之前将 parameter 进行 clip 。 | |||||
:param float clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数 | |||||
:param str clip_type: 支持'norm', 'value'两种: | |||||
:param clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数 | |||||
:param clip_type: 支持'norm', 'value'两种: | |||||
1. 'norm', 将gradient的norm rescale到[-clip_value, clip_value] | |||||
2. 'value', 将gradient限制在[-clip_value, clip_value], | |||||
小于-clip_value的gradient被赋值为-clip_value; | |||||
大于clip_value的gradient被赋值为clip_value. | |||||
1. 'norm', 将gradient的norm rescale到[-clip_value, clip_value] | |||||
2. 'value', 将gradient限制在[-clip_value, clip_value], | |||||
小于-clip_value的gradient被赋值为-clip_value;大于clip_value的gradient被赋值为clip_value. | |||||
:param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。 | |||||
如果为None则默认对 Trainer 的 optimizers 中所有参数进行梯度裁剪。 | |||||
""" | |||||
:param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。 | |||||
如果为None则默认对 Trainer 的 optimizers 中所有参数进行梯度裁剪。 | |||||
""" | |||||
def __init__(self, clip_value:int=1, clip_type:str='norm', | |||||
parameters:Union["torch.Tensor", List["torch.Tensor"]]=None): | |||||
super().__init__() | super().__init__() | ||||
from torch import nn | from torch import nn | ||||
@@ -2,21 +2,23 @@ __all__ = [ | |||||
'TorchWarmupCallback' | 'TorchWarmupCallback' | ||||
] | ] | ||||
import math | import math | ||||
from typing import Union | |||||
from ..callback import Callback | from ..callback import Callback | ||||
class TorchWarmupCallback(Callback): | class TorchWarmupCallback(Callback): | ||||
def __init__(self, warmup=0.1, schedule='constant'): | |||||
r""" | |||||
调整 learning rate 的 callback 。仅在实际发生参数更新的情况下 | |||||
:param int,float warmup: 如果warmup为int,则在该step之前,learning rate根据schedule的策略变化; 如果warmup为float, | |||||
如0.1, 则前10%的step是按照schedule策略调整learning rate。 | |||||
:param str schedule: 以哪种方式调整。 | |||||
linear: 前warmup的step上升到指定的learning rate(从Trainer中的optimizer处获取的), 后warmup的step下降到0; | |||||
constant前warmup的step上升到指定learning rate,后面的step保持learning rate. | |||||
""" | |||||
r""" | |||||
调整 learning rate 的 callback 。 | |||||
:param warmup: 如果warmup为int,则在该step之前,learning rate根据schedule的策略变化; 如果warmup为float, | |||||
如0.1, 则前10%的step是按照schedule策略调整learning rate。 | |||||
:param schedule: 以哪种方式调整。 | |||||
1. linear: 前warmup的step上升到指定的learning rate(从Trainer中的optimizer处获取的), 后warmup的step下降到0; | |||||
2. constant前warmup的step上升到指定learning rate,后面的step保持learning rate. | |||||
""" | |||||
def __init__(self, warmup:Union[int, float]=0.1, schedule:str='constant'): | |||||
super().__init__() | super().__init__() | ||||
self.warmup = max(warmup, 0.) | self.warmup = max(warmup, 0.) | ||||
@@ -82,16 +82,26 @@ def _get_backend() -> str: | |||||
class Collator: | class Collator: | ||||
def __init__(self, backend='auto'): | |||||
""" | |||||
用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 | |||||
可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。Collator 在第一次进行 pad 的 | |||||
时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 | |||||
""" | |||||
用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 | |||||
哦安定一个 field 是否可以 pad 的方式为:(1)当前这个 field 是否所有对象都是一样的数据类型;(因此,如果某 field 的数据有些是float | |||||
有些是 int 将知道该 field 被判定为不可 pad 类型。)(2)当前这个 field 是否每个 sample 都具有一样的深度;(因此,例如有个 field 的 | |||||
数据转为 batch 类型后为 [1, [1,2]], 会被判定为不可 pad ,因为第一个 sample 与 第二个 sample 深度不同)(3)当前这个 field 的类 | |||||
型是否是可以 pad (例如 str 类型的数据)。可以通过设置 logger.setLevel('debug') 来打印是判定不可 pad 的原因。 | |||||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', auto, None]。 | |||||
若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对不能进行 pad 的数据没用影响,不能 pad | |||||
的数据返回一定是 list 。 | |||||
""" | |||||
todo 补充 code example 。 | |||||
如果需要将某个本可以 pad 的 field 设置为不可 pad ,则可以通过 :meth:`~fastNLP.Collator.set_pad` 的 pad_val 设置为 None 实现。 | |||||
如果需要某些 field 不要包含在 pad 之后的结果中,可以使用 :meth:`~fastNLP.Collator.set_ignore` 进行设置。 | |||||
Collator 在第一次进行 pad 的时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应 | |||||
的 Padder 给对应的 field 。 | |||||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', auto, None]。 | |||||
若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对不能进行 pad 的数据没用影响,不能 pad | |||||
的数据返回一定是 list 。 | |||||
""" | |||||
def __init__(self, backend='auto'): | |||||
self.unpack_batch_func = None | self.unpack_batch_func = None | ||||
self.pack_batch_func = None | self.pack_batch_func = None | ||||
self.ignore_fields = set() | self.ignore_fields = set() | ||||
@@ -264,9 +274,8 @@ class Collator: | |||||
def set_ignore(self, *field_names) -> "Collator": | def set_ignore(self, *field_names) -> "Collator": | ||||
""" | """ | ||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | ||||
Example:: | |||||
collator.set_ignore('field1', 'field2') | |||||
>>> collator = Collator().set_ignore('field1', 'field2') | |||||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | ||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 | field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 | ||||
@@ -9,9 +9,9 @@ class MappingPackerUnpacker: | |||||
""" | """ | ||||
将 Sequence[Mapping] 转为 Dict 。例如 [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] -> {'a': [[1, 2], [3]], 'b': [1, 2]} | 将 Sequence[Mapping] 转为 Dict 。例如 [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] -> {'a': [[1, 2], [3]], 'b': [1, 2]} | ||||
:param batch: | |||||
:param ignore_fields: | |||||
:param input_fields: | |||||
:param batch: 需要 unpack 的 batch 数据。 | |||||
:param ignore_fields: 需要忽略的 field 。 | |||||
:param input_fields: 需要设置为 input 的 field 。 | |||||
:return: | :return: | ||||
""" | """ | ||||
dict_batch = defaultdict(list) | dict_batch = defaultdict(list) | ||||
@@ -29,13 +29,13 @@ class MappingPackerUnpacker: | |||||
class NestedMappingPackerUnpacker: | class NestedMappingPackerUnpacker: | ||||
@staticmethod | @staticmethod | ||||
def unpack_batch(batch:Sequence[Mapping], ignore_fields:set, input_fields:Dict): | |||||
def unpack_batch(batch:Sequence[Mapping], ignore_fields:set, input_fields:Dict)->Dict: | |||||
""" | """ | ||||
将 nested 的 dict 中的内容展开到一个 flat dict 中 | 将 nested 的 dict 中的内容展开到一个 flat dict 中 | ||||
:param batch: | |||||
:param batch: 需要 unpack 的 batch 数据。 | |||||
:param ignore_fields: 需要忽略的 field 。 | :param ignore_fields: 需要忽略的 field 。 | ||||
:param input_fields: 不需要继续往下衍射的 | |||||
:param input_fields: 需要设置为 input 的 field 。 | |||||
:return: | :return: | ||||
""" | """ | ||||
dict_batch = defaultdict(list) | dict_batch = defaultdict(list) | ||||
@@ -72,8 +72,9 @@ class SequencePackerUnpacker: | |||||
""" | """ | ||||
将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [1, 2]} | 将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [1, 2]} | ||||
:param batch: | |||||
:param ignore_fields: 需要忽略的field | |||||
:param batch: 需要 unpack 的 batch 数据。 | |||||
:param ignore_fields: 需要忽略的 field 。 | |||||
:param input_fields: 需要设置为 input 的 field 。 | |||||
:return: | :return: | ||||
""" | """ | ||||
dict_batch = defaultdict(list) | dict_batch = defaultdict(list) | ||||
@@ -90,7 +90,7 @@ class JittorNumberPadder(Padder): | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val, dtype): | |||||
def pad(batch_field, pad_val=0, dtype=None): | |||||
return jittor.Var(np.array(batch_field, dtype=dtype)) | return jittor.Var(np.array(batch_field, dtype=dtype)) | ||||
@@ -107,7 +107,7 @@ class JittorSequencePadder(Padder): | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val, dtype): | |||||
def pad(batch_field, pad_val=0, dtype=None): | |||||
tensor = get_padded_jittor_tensor(batch_field, dtype=dtype, pad_val=pad_val) | tensor = get_padded_jittor_tensor(batch_field, dtype=dtype, pad_val=pad_val) | ||||
return tensor | return tensor | ||||
@@ -125,7 +125,7 @@ class JittorTensorPadder(Padder): | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val, dtype): | |||||
def pad(batch_field, pad_val=0, dtype=None): | |||||
try: | try: | ||||
if not isinstance(batch_field[0], jittor.Var): | if not isinstance(batch_field[0], jittor.Var): | ||||
batch_field = [jittor.Var(np.array(field.tolist(), dtype=dtype)) for field in batch_field] | batch_field = [jittor.Var(np.array(field.tolist(), dtype=dtype)) for field in batch_field] | ||||
@@ -30,53 +30,65 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||||
class NumpyNumberPadder(Padder): | class NumpyNumberPadder(Padder): | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
可以将形如 [1, 2, 3] 这类的数据转为 np.array([1, 2, 3]) | |||||
""" | |||||
可以将形如 [1, 2, 3] 这类的数据转为 np.array([1, 2, 3]) 。可以通过: | |||||
>>> NumpyNumberPadder.pad([1, 2, 3]) | |||||
使用。 | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val, dtype): | |||||
def pad(batch_field, pad_val=0, dtype=None): | |||||
return np.array(batch_field, dtype=dtype) | return np.array(batch_field, dtype=dtype) | ||||
class NumpySequencePadder(Padder): | class NumpySequencePadder(Padder): | ||||
""" | |||||
将类似于 [[1], [1, 2]] 的内容 pad 为 np.array([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||||
可以通过以下的方式直接使用: | |||||
>>> NumpySequencePadder.pad([[1], [1, 2]], pad_val=-100, dtype=float) | |||||
[[ 1. -100.] | |||||
[ 1. 2.]] | |||||
:param pad_val: pad 的值是多少。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | ||||
""" | |||||
将类似于 [[1], [1, 2]] 的内容 pad 为 np.array([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||||
:param pad_val: pad 的值是多少。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | |||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val, dtype): | |||||
def pad(batch_field, pad_val=0, dtype=None): | |||||
return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val) | return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val) | ||||
class NumpyTensorPadder(Padder): | class NumpyTensorPadder(Padder): | ||||
""" | |||||
pad 类似于 [np.array([3, 4]), np.array([1])] 的 field 。若内部元素不为 np.ndarray ,则必须含有 tolist() 方法。 | |||||
>>> NumpyTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100) | |||||
[[ 3. 4.] | |||||
[ 1. -100.]] | |||||
:param pad_val: pad 的值是多少。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | ||||
""" | |||||
pad 类似于 [np.array([3, 4], np.array([1])] 的 field 。若内部元素不为 np.ndarray ,则必须含有 tolist() 方法。 | |||||
:param pad_val: pad 的值是多少。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | |||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val, dtype): | |||||
def pad(batch_field, pad_val=0, dtype=None): | |||||
try: | try: | ||||
if not isinstance(batch_field[0], np.ndarray): | if not isinstance(batch_field[0], np.ndarray): | ||||
batch_field = [np.array(field.tolist(), dtype=dtype) for field in batch_field] | batch_field = [np.array(field.tolist(), dtype=dtype) for field in batch_field] | ||||
@@ -1,5 +1,9 @@ | |||||
class Padder: | class Padder: | ||||
""" | |||||
所有 Padder 对象父类,所有的 Padder 对象都会实现 pad(batch_field, pad_val=0, dtype=None) 的静态函数。 | |||||
""" | |||||
def __init__(self, pad_val, dtype): | def __init__(self, pad_val, dtype): | ||||
self.pad_val = pad_val | self.pad_val = pad_val | ||||
self.dtype = dtype | self.dtype = dtype | ||||
@@ -8,19 +12,19 @@ class Padder: | |||||
return self.pad(batch_field=batch_field, pad_val=self.pad_val, dtype=self.dtype) | return self.pad(batch_field=batch_field, pad_val=self.pad_val, dtype=self.dtype) | ||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val, dtype): | |||||
def pad(batch_field, pad_val=0, dtype=None): | |||||
raise NotImplementedError() | raise NotImplementedError() | ||||
class NullPadder(Padder): | class NullPadder(Padder): | ||||
def __init__(self, ele_dtype=None, pad_val=None, dtype=None): | |||||
""" | |||||
不进行任何 检查 与 pad 的空 padder 。 | |||||
""" | |||||
不进行任何 检查 与 pad 的空 padder 。 | |||||
:param ele_dtype: | |||||
:param pad_val: | |||||
:param dtype: | |||||
""" | |||||
:param ele_dtype: | |||||
:param pad_val: | |||||
:param dtype: | |||||
""" | |||||
def __init__(self, ele_dtype=None, pad_val=None, dtype=None): | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
def __call__(self, batch_field): | def __call__(self, batch_field): | ||||
@@ -80,55 +80,55 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||||
class PaddleNumberPadder(Padder): | class PaddleNumberPadder(Padder): | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
可以将形如 [1, 2, 3] 这类的数据转为 paddle.Tensor([1, 2, 3]) | |||||
""" | |||||
可以将形如 [1, 2, 3] 这类的数据转为 paddle.Tensor([1, 2, 3]) | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 | |||||
""" | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 | |||||
""" | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
# 仅当 ele_dtype 是 python number/ numpy number 或者 tensor | # 仅当 ele_dtype 是 python number/ numpy number 或者 tensor | ||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val, dtype): | |||||
def pad(batch_field, pad_val=0, dtype=None): | |||||
return paddle.to_tensor(batch_field, dtype=dtype) | return paddle.to_tensor(batch_field, dtype=dtype) | ||||
class PaddleSequencePadder(Padder): | class PaddleSequencePadder(Padder): | ||||
def __init__(self, ele_dtype=None, pad_val=0, dtype=None): | |||||
""" | |||||
将类似于 [[1], [1, 2]] 的内容 pad 为 paddle.Tensor([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||||
""" | |||||
将类似于 [[1], [1, 2]] 的内容 pad 为 paddle.Tensor([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||||
:param pad_val: pad 的值。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 | |||||
""" | |||||
:param pad_val: pad 的值。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 | |||||
""" | |||||
def __init__(self, ele_dtype=None, pad_val=0, dtype=None): | |||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val, dtype): | |||||
def pad(batch_field, pad_val=0, dtype=None): | |||||
tensor = get_padded_paddle_tensor(batch_field, dtype=dtype, pad_val=pad_val) | tensor = get_padded_paddle_tensor(batch_field, dtype=dtype, pad_val=pad_val) | ||||
return tensor | return tensor | ||||
class PaddleTensorPadder(Padder): | class PaddleTensorPadder(Padder): | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
目前支持 [paddle.tensor([3, 2], paddle.tensor([2, 1])] 类似的,若内部元素不为 paddle.tensor ,则必须含有 tolist() 方法。 | |||||
""" | |||||
目前支持 [paddle.tensor([3, 2], paddle.tensor([2, 1])] 类似的,若内部元素不为 paddle.tensor ,则必须含有 tolist() 方法。 | |||||
:param pad_val: pad 的值。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 | |||||
""" | |||||
:param pad_val: pad 的值。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 | |||||
""" | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val, dtype): | |||||
def pad(batch_field, pad_val=0, dtype=None): | |||||
try: | try: | ||||
if not isinstance(batch_field[0], paddle.Tensor): | if not isinstance(batch_field[0], paddle.Tensor): | ||||
batch_field = [np.array(field.tolist()) for field in batch_field] | batch_field = [np.array(field.tolist()) for field in batch_field] | ||||
@@ -26,14 +26,14 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||||
class RawNumberPadder(Padder): | class RawNumberPadder(Padder): | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
可以将形如 [1, 2, 3] 这类的数据转为 [1, 2, 3] 。实际上该 padder 无意义。 | |||||
""" | |||||
可以将形如 [1, 2, 3] 这类的数据转为 [1, 2, 3] 。实际上该 padder 无意义。 | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@@ -41,24 +41,24 @@ class RawNumberPadder(Padder): | |||||
return batch_field | return batch_field | ||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val, dtype): | |||||
def pad(batch_field, pad_val=0, dtype=None): | |||||
raise NotImplementedError() | raise NotImplementedError() | ||||
class RawSequencePadder(Padder): | class RawSequencePadder(Padder): | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
将类似于 [[1], [1, 2]] 的内容 pad 为 [[1, 0], [1, 2]] 。可以 pad 多重嵌套的数据。 | |||||
""" | |||||
将类似于 [[1], [1, 2]] 的内容 pad 为 [[1, 0], [1, 2]] 。可以 pad 多重嵌套的数据。 | |||||
:param pad_val: pad 的值 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | |||||
:param pad_val: pad 的值 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val, dtype): | |||||
def pad(batch_field, pad_val=0, dtype=None): | |||||
""" | """ | ||||
:param batch_field: | :param batch_field: | ||||
@@ -70,19 +70,19 @@ class RawSequencePadder(Padder): | |||||
class RawTensorPadder(Padder): | class RawTensorPadder(Padder): | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
将类似于 [[1], [1, 2]] 的内容 pad 为 [[1, 0], [1, 2]] 。可以 pad 多重嵌套的数据。 | |||||
""" | |||||
将类似于 [[1], [1, 2]] 的内容 pad 为 [[1, 0], [1, 2]] 。可以 pad 多重嵌套的数据。 | |||||
:param pad_val: pad 的值 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | |||||
:param pad_val: pad 的值 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val, dtype): | |||||
def pad(batch_field, pad_val=0, dtype=None): | |||||
""" | """ | ||||
:param batch_field: | :param batch_field: | ||||
@@ -64,54 +64,61 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||||
class TorchNumberPadder(Padder): | class TorchNumberPadder(Padder): | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
可以将形如 [1, 2, 3] 这类的数据转为 torch.Tensor([1, 2, 3]) | |||||
""" | |||||
可以将形如 [1, 2, 3] 这类的数据转为 torch.Tensor([1, 2, 3]) | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 | |||||
""" | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 | |||||
""" | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val, dtype): | |||||
def pad(batch_field, pad_val=0, dtype=None): | |||||
return torch.tensor(batch_field, dtype=dtype) | return torch.tensor(batch_field, dtype=dtype) | ||||
class TorchSequencePadder(Padder): | class TorchSequencePadder(Padder): | ||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
将类似于 [[1], [1, 2]] 的内容 pad 为 torch.Tensor([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||||
""" | |||||
将类似于 [[1], [1, 2]] 的内容 pad 为 torch.Tensor([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||||
:param pad_val: 需要 pad 的值。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 | |||||
""" | |||||
:param pad_val: 需要 pad 的值。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 | |||||
""" | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val, dtype): | |||||
def pad(batch_field, pad_val=0, dtype=None): | |||||
tensor = get_padded_torch_tensor(batch_field, dtype=dtype, pad_val=pad_val) | tensor = get_padded_torch_tensor(batch_field, dtype=dtype, pad_val=pad_val) | ||||
return tensor | return tensor | ||||
class TorchTensorPadder(Padder): | class TorchTensorPadder(Padder): | ||||
""" | |||||
目前支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的。若内部元素不为 torch.tensor ,则必须含有 tolist() 方法。 | |||||
>>> TorchTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100) | |||||
[[ 3. 4.] | |||||
[ 1. -100.]] | |||||
>>> TorchTensorPadder.pad([torch.LongTensor([3, 4]), torch.LongTensor([1])], pad_val=-100) | |||||
tensor([[ 3, 4], | |||||
[ 1, -100]]) | |||||
:param pad_val: 需要 pad 的值。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 | |||||
""" | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | ||||
""" | |||||
目前支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的。若内部元素不为 torch.tensor ,则必须含有 tolist() 方法。 | |||||
:param pad_val: 需要 pad 的值。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 | |||||
""" | |||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val, dtype): | |||||
def pad(batch_field, pad_val=0, dtype=None): | |||||
device = None | device = None | ||||
try: | try: | ||||
if not isinstance(batch_field[0], torch.Tensor): | if not isinstance(batch_field[0], torch.Tensor): | ||||
@@ -1,7 +1,15 @@ | |||||
r""" | |||||
``Evaluator`` 是新版 fastNLP 中用来进行评测模型的评测器,其与 ``Trainer`` 相对应,二者共同构建起了 fastNLP 中**训练**和**评测**的框架。 | |||||
``Evaluator`` 的整体架构与 ``Trainer`` 类似,也是利用 ``Driver`` 来负责底层的评测逻辑。通过使用 ``Evaluator``,您可以快速、方便、准确地 | |||||
对您的模型进行全方位地评测。 | |||||
.. note:: | |||||
``Trainer`` 通过来自己内部内置一个 ``Evaluator`` 实例来支持在训练过程中进行验证的功能; | |||||
""" | |||||
from typing import Union, List, Optional, Dict, Callable | from typing import Union, List, Optional, Dict, Callable | ||||
from functools import partial | |||||
from dataclasses import is_dataclass | from dataclasses import is_dataclass | ||||
import sys | |||||
__all__ = [ | __all__ = [ | ||||
'Evaluator' | 'Evaluator' | ||||
@@ -20,60 +28,96 @@ from fastNLP.core.log import logger | |||||
class Evaluator: | class Evaluator: | ||||
""" | |||||
用于评测模型性能好坏的评测器; | |||||
.. note:: | |||||
``Evaluator`` 与 ``Trainer`` 类似,都是使用 ``Driver`` 作为底层来实现评测或者训练,因此大多数与 ``Trainer`` 同名的参数的意义和使用都与 | |||||
``Trainer`` 中的参数相同,对于这些参数,您可以参考 ``Trainer`` 的文档来获取更详细的信息;详见 :class:`~fastNLP.core.controllers.trainer.Trainer`; | |||||
:param model: 训练所需要的模型,例如 ``torch.nn.Module``,等价于 ``Trainer`` 中的 ``model`` 参数; | |||||
:param dataloaders: 用于评测的数据集。如果为多个,您需要使用 ``dict`` 传入,即对每一个数据集标上用于标识它们的标签; | |||||
:param metrics: 评测时使用的指标。注意该参数必须为 ``dict`` 类型,其中 ``key`` 为一个 ``metric`` 的名称,``value`` 为具体的 ``Metric`` 对象。目前支持以下 metrics: | |||||
1. fastNLP 自己的 ``metric``:详见 :class:`fastNLP.core.metrics.Metric`; | |||||
2. torchmetrics; | |||||
3. allennlp.training.metrics; | |||||
4. paddle.metric; | |||||
:param driver: 等价于 ``Trainer`` 中的 ``driver`` 参数; | |||||
.. note:: | |||||
如果在您的脚本中在初始化 ``Evaluator`` 前也初始化了 ``Trainer`` 进行训练,那么强烈建议您直接将 ``trainer.driver`` 传入 ``Evaluator`` 当做该参数的值; | |||||
.. code-block:: | |||||
# 初始化 Trainer | |||||
trainer = Trainer( | |||||
... | |||||
driver='torch', | |||||
device=[0,1] | |||||
) | |||||
trainer.run() | |||||
# 此时再初始化 Evaluator 时应当直接使用 trainer.driver; | |||||
evaluator = Evaluator( | |||||
... | |||||
driver=trainer.driver | |||||
) | |||||
:param device: 等价于 ``Trainer`` 中的 ``device`` 参数; | |||||
:param evaluate_batch_step_fn: 您可以传入该参数来定制每次评测一个 batch 的数据时所执行的函数。该函数应接受的两个参数为 ``evaluator`` 和 ``batch``, | |||||
不需要有返回值;可以参考 :meth:`~fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop.batch_step_fn`; | |||||
:param evaluate_fn: 用来控制 ``Evaluator`` 在评测的前向传播过程中调用的是哪一个函数,例如对于 pytorch 而言,通过该参数确定使用的是 ``model.evaluate_step`` 还是 | |||||
``model.forward``(不同训练框架所使用的的前向传播函数的方法名称不同); | |||||
1. 如果该值是 ``None``,那么我们会默认使用 ``evaluate_step`` 当做前向传播的函数,如果在模型中没有找到该方法,则使用训练框架默认的前向传播函数; | |||||
2. 如果为 ``str`` 类型,例如为 ``my_evaluate_step_fn``,则尝试寻找 ``model.my_evaluate_step_fn``,如果找不到则直接报错; | |||||
:param input_mapping: 等价于 ``Trainer`` 中的 ``input_mapping`` 参数;对具体的用于评测一个 batch 的数据使用 ``input_mapping`` 处理之后再输入到 ``model`` 以及 ``metric`` 中。如果针对 | |||||
``model`` 和 ``metric`` 需要不同的 ``mapping``,请考虑使用 ``evaluate_batch_step_fn`` 参数定制; | |||||
.. todo:: | |||||
之后链接上 参数匹配 的文档; | |||||
:param output_mapping: 等价于 ``Trainer`` 中的 ``output_mapping`` 参数;对 ``model`` 输出的内容,将通过 ``output_mapping`` 处理之后再输入到 ``metric`` 中; | |||||
:param model_wo_auto_param_call: 等价于 ``Trainer`` 中的 ``model_wo_auto_param_call`` 参数; | |||||
.. note:: | |||||
一个十分需要注意的问题在于 ``model_wo_auto_param_call`` 只会关闭部分的参数匹配,即指挥关闭前向传播时的参数匹配,但是由于 ``Evaluator`` 中 | |||||
``metric`` 的计算都是自动化的,因此其一定需要参数匹配:根据 ``metric.update`` 的函数签名直接从字典数据中抽取其需要的参数传入进去; | |||||
:param fp16: 是否在评测时使用 fp16; | |||||
:param verbose: 是否打印 evaluate 的结果; | |||||
:kwargs: | |||||
* *torch_kwargs* -- 等价于 ``Trainer`` 中的 ``torch_kwargs`` 参数; | |||||
* *data_device* -- 等价于 ``Trainer`` 中的 ``data_device`` 参数; | |||||
* *model_use_eval_mode* (``bool``) -- | |||||
是否在评测的时候将 ``model`` 的状态设置成 ``eval`` 状态。在 ``eval`` 状态下,``model`` 的 | |||||
``dropout`` 与 ``batch normalization`` 将会关闭。默认为 ``True``。如果为 ``False``,``fastNLP`` 不会对 ``model`` 的 ``evaluate`` 状态做任何设置。无论 | |||||
该值是什么,``fastNLP`` 都会在评测后将 ``model`` 的状态设置为 ``train``; | |||||
* *use_dist_sampler* -- | |||||
是否使用分布式评测的方式。仅当 ``driver`` 为分布式类型时,该参数才有效。默认为根据 ``driver`` 是否支持 | |||||
分布式进行设置。如果为 ``True``,将使得每个进程上的 ``dataloader`` 自动使用不同数据,所有进程的数据并集是整个数据集; | |||||
* *output_from_new_proc* -- 等价于 ``Trainer`` 中的 ``output_from_new_proc`` 参数; | |||||
* *progress_bar* -- 等价于 ``Trainer`` 中的 ``progress_bar`` 参数; | |||||
""" | |||||
driver: Driver | driver: Driver | ||||
_evaluate_batch_loop: Loop | _evaluate_batch_loop: Loop | ||||
def __init__(self, model, dataloaders, metrics: Optional[Union[Dict, Metric]] = None, | |||||
def __init__(self, model, dataloaders, metrics: Optional[Dict] = None, | |||||
driver: Union[str, Driver] = 'torch', device: Optional[Union[int, List[int], str]] = None, | driver: Union[str, Driver] = 'torch', device: Optional[Union[int, List[int], str]] = None, | ||||
evaluate_batch_step_fn: Optional[callable] = None, evaluate_fn: Optional[str] = None, | evaluate_batch_step_fn: Optional[callable] = None, evaluate_fn: Optional[str] = None, | ||||
input_mapping: Optional[Union[Callable, Dict]] = None, | input_mapping: Optional[Union[Callable, Dict]] = None, | ||||
output_mapping: Optional[Union[Callable, Dict]] = None, model_wo_auto_param_call: bool = False, | output_mapping: Optional[Union[Callable, Dict]] = None, model_wo_auto_param_call: bool = False, | ||||
fp16: bool = False, verbose: int = 1, **kwargs): | fp16: bool = False, verbose: int = 1, **kwargs): | ||||
""" | |||||
用于对数据进行评测。 | |||||
:param model: 待测试的模型,如果传入的 driver 为 Driver 实例,该参数将被忽略。 | |||||
:param dataloaders: 待评测的数据集。如果为多个,请使用 dict 传入。 | |||||
:param metrics: 使用的 metric 。必须为 dict 类型,其中 key 为 metric 的名称,value 为一个 Metric 对象。支持 fastNLP 的 | |||||
metric ,torchmetrics,allennlpmetrics 等。 | |||||
:param driver: 使用 driver 。 | |||||
:param device: 使用的设备。 | |||||
:param evaluate_batch_step_fn: 定制每次 evaluate batch 执行的函数。该函数应接受的两个参数为 `evaluator` 和 `batch`, | |||||
不需要有返回值;可以参考 fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop中的batch_step_fn函数。 | |||||
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 | |||||
`model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有 | |||||
找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。 | |||||
:param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中。如果针对 | |||||
model 和 metric 需要不同的 mapping,请考虑使用 evaluate_batch_step_fn 参数定制。 | |||||
:param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 | |||||
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | |||||
如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | |||||
为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `evaluate_step` 和 `test_step`; | |||||
:param fp16: 是否使用 fp16 。 | |||||
:param verbose: 是否打印 evaluate 的结果。 | |||||
:kwargs: | |||||
* *torch_kwargs* -- 用于在指定 ``driver`` 为 'torch' 时设定具体 driver 实例的一些参数: | |||||
* ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入 | |||||
{'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等; | |||||
* torch_non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; | |||||
* *data_device* -- 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上; | |||||
注意如果 model_device 为 None,那么 data_device 不会起作用; | |||||
* *model_use_eval_mode* (``bool``) -- | |||||
是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的 | |||||
dropout 与 batch normalization 将会关闭。默认为True。如果为 False,fastNLP 不会对 model 的 evaluate 状态做任何设置。无论 | |||||
该值是什么,fastNLP 都会在 evaluate 接受后将 model 的状态设置为 train 。 | |||||
* *use_dist_sampler* -- | |||||
是否使用分布式evaluate的方式。仅当 driver 为分布式类型时,该参数才有效。默认为根据 driver 是否支持 | |||||
分布式进行设置。如果为True,将使得每个进程上的 dataloader 自动使用不同数据,所有进程的数据并集是整个数据集。 | |||||
* *output_from_new_proc* -- | |||||
应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | |||||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | |||||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | |||||
* *progress_bar* -- | |||||
evaluate 的时候显示的 progress bar 。目前支持三种 [None, 'raw', 'rich', 'auto'], auto 表示如果检测 | |||||
到当前terminal为交互型则使用 rich,否则使用 raw。 | |||||
""" | |||||
self.model = model | self.model = model | ||||
self.metrics = metrics | self.metrics = metrics | ||||
self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, | self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, | ||||
@@ -127,19 +171,22 @@ class Evaluator: | |||||
self.driver.barrier() | self.driver.barrier() | ||||
def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: | |||||
def run(self, num_eval_batch_per_dl: int = -1) -> Dict: | |||||
""" | """ | ||||
该函数是在 ``Evaluator`` 初始化后用于真正开始评测的函数; | |||||
返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 | 返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 | ||||
如果存在多个metric,一个dataloader的情况,key的命名规则是 | |||||
metric_indicator_name#metric_name | |||||
如果存在多个数据集,一个metric的情况,key的命名规则是 | |||||
metric_indicator_name#metric_name#dataloader_name (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。 | |||||
1. 如果存在多个metric,一个dataloader的情况,key的命名规则是 | |||||
``metric_indicator_name#metric_name`` | |||||
2. 如果存在多个数据集,一个metric的情况,key的命名规则是 | |||||
``metric_indicator_name#metric_name#dataloader_name`` (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。 | |||||
如果存在多个metric,多个dataloader的情况,key的命名规则是 | 如果存在多个metric,多个dataloader的情况,key的命名规则是 | ||||
metric_indicator_name#metric_name#dataloader_name | |||||
其中 metric_indicator_name 可能不存在。 | |||||
``metric_indicator_name#metric_name#dataloader_name`` | |||||
其中 metric_indicator_name 可能不存在; | |||||
:param num_eval_batch_per_dl: 每个 dataloader 测试多少个 batch 的数据,-1 为测试所有数据。 | |||||
:return: | |||||
:param num_eval_batch_per_dl: 每个 dataloader 测试前多少个 batch 的数据,-1 为测试所有数据。 | |||||
:return: 返回评测得到的结果,是一个没有嵌套的字典; | |||||
""" | """ | ||||
assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." | assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." | ||||
assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0." | assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0." | ||||
@@ -388,5 +435,8 @@ class _MetricsWrapper: | |||||
_results = metric.accumulate() | _results = metric.accumulate() | ||||
else: | else: | ||||
raise RuntimeError(f"Not support `{type(metric)}` for now.") | raise RuntimeError(f"Not support `{type(metric)}` for now.") | ||||
results[metric_name] = _results | |||||
if _results is not None: | |||||
results[metric_name] = _results | |||||
else: | |||||
logger.warning_once(f"Metric:{metric_name} returns None when getting metric results.") | |||||
return results | return results |
@@ -10,16 +10,21 @@ from fastNLP.core.utils import match_and_substitute_params | |||||
class EvaluateBatchLoop(Loop): | class EvaluateBatchLoop(Loop): | ||||
r""" | |||||
``EvaluateBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的评测迭代过程; | |||||
:param batch_step_fn: 您可以传入该参数来替换默认的 bath_step_fn; | |||||
""" | |||||
def __init__(self, batch_step_fn:Optional[Callable]=None): | def __init__(self, batch_step_fn:Optional[Callable]=None): | ||||
if batch_step_fn is not None: | if batch_step_fn is not None: | ||||
self.batch_step_fn = batch_step_fn | self.batch_step_fn = batch_step_fn | ||||
def run(self, evaluator, dataloader) -> Dict: | def run(self, evaluator, dataloader) -> Dict: | ||||
""" | |||||
r""" | |||||
需要返回在传入的 dataloader 中的 evaluation 结果 | 需要返回在传入的 dataloader 中的 evaluation 结果 | ||||
:param evaluator: Evaluator 对象 | :param evaluator: Evaluator 对象 | ||||
:param dataloader: 当前需要进行 evaluate 的dataloader | |||||
:param dataloader: 当前需要进行评测的dataloader | |||||
:return: | :return: | ||||
""" | """ | ||||
iterator = iter(dataloader) | iterator = iter(dataloader) | ||||
@@ -27,24 +32,32 @@ class EvaluateBatchLoop(Loop): | |||||
while True: | while True: | ||||
try: | try: | ||||
batch = next(iterator) | batch = next(iterator) | ||||
batch = match_and_substitute_params(evaluator.input_mapping, batch) | |||||
batch = evaluator.move_data_to_device(batch) | |||||
except StopIteration: | except StopIteration: | ||||
break | break | ||||
try: | |||||
batch = match_and_substitute_params(evaluator.input_mapping, batch) | |||||
batch = evaluator.move_data_to_device(batch) | |||||
self.batch_step_fn(evaluator, batch) | |||||
batch_idx += 1 | |||||
evaluator.update_progress_bar(batch_idx, evaluator.cur_dataloader_name) | |||||
except BaseException as e: | except BaseException as e: | ||||
if callable(getattr(dataloader, 'get_batch_indices', None)): | if callable(getattr(dataloader, 'get_batch_indices', None)): | ||||
indices = dataloader.get_batch_indices() | indices = dataloader.get_batch_indices() | ||||
logger.error(f"Exception happens when evaluating on samples: {indices}") | logger.error(f"Exception happens when evaluating on samples: {indices}") | ||||
raise e | raise e | ||||
self.batch_step_fn(evaluator, batch) | |||||
batch_idx += 1 | |||||
evaluator.update_progress_bar(batch_idx, evaluator.cur_dataloader_name) | |||||
# 获取metric结果。返回的dict内容示例为{'metric_name1': metric_results, 'metric_name2': metric_results, ...} | # 获取metric结果。返回的dict内容示例为{'metric_name1': metric_results, 'metric_name2': metric_results, ...} | ||||
results = evaluator.get_metric() | results = evaluator.get_metric() | ||||
return results | return results | ||||
@staticmethod | @staticmethod | ||||
def batch_step_fn(evaluator, batch): | def batch_step_fn(evaluator, batch): | ||||
r""" | |||||
针对一个 batch 的数据的评测过程; | |||||
:param evaluator: Evaluator 对象 | |||||
:param batch: 当前需要评测的一个 batch 的数据; | |||||
""" | |||||
outputs = evaluator.evaluate_step(batch) # 将batch输入到model中得到结果 | outputs = evaluator.evaluate_step(batch) # 将batch输入到model中得到结果 | ||||
evaluator.update(batch, outputs) # evaluator将根据metric的形参名字从batch/outputs中取出对应的值进行赋值 | evaluator.update(batch, outputs) # evaluator将根据metric的形参名字从batch/outputs中取出对应的值进行赋值 |
@@ -1,3 +1,12 @@ | |||||
r""" | |||||
``TrainBatchLoop`` 和 ``EvaluateBatchLoop`` 的父类,为了在实现 fastNLP 主要功能的同时保证 fastNLP 的易用性和代码的易读性,我们只对 | |||||
训练中的循环做了非常简单的抽象,``Loop`` 表示的是在训练或者评测的过程中针对单独一个 ``dataloader`` 的一个 ``epoch`` 的运算过程; | |||||
更为具体的使用详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop` 和 | |||||
:class:`~fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop` ; | |||||
""" | |||||
from typing import Union | |||||
__all__ = [ | __all__ = [ | ||||
'Loop' | 'Loop' | ||||
@@ -5,13 +14,25 @@ __all__ = [ | |||||
class Loop: | class Loop: | ||||
r""" | |||||
``TrainBatchLoop`` 和 ``EvaluateBatchLoop`` 的父类,您可以继承此类来定制自己的训练或者评测 ``loop``; | |||||
""" | |||||
def run(self, *args, **kwargs): | |||||
""" | |||||
该循环的主要运行过程; | |||||
""" | |||||
def run(self, controller: Union["Trainer", "Evaluator"], dataloader): | |||||
r""" | |||||
遍历参数 ``dataloader`` 的所有数据,使用 ``controller`` 进行训练或者评测; | |||||
.. note:: | |||||
``Trainer`` 和 ``Evaluator`` 中都提供了方便您进行定制 ``Loop`` 的接口函数,例如 ``Trainer.train_step``,``Trainer.backward``等; | |||||
在定制您自己的 ``TrainBatchLoop`` 时,请务必记得在正确的时机调用对应的 callback 函数,详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop` | |||||
中对于 callback 函数的调用; | |||||
def step(self, *args, **kwargs): | |||||
""" | """ | ||||
该循环运行过程中的一步; | |||||
@staticmethod | |||||
def batch_step_fn(controller: Union["Trainer", "Evaluator"], batch): | |||||
r""" | |||||
对于具体的一个 batch 的数据,实现训练或者评测过程中的一步; | |||||
""" | """ |
@@ -11,43 +11,66 @@ from fastNLP.core.utils.exceptions import EarlyStopException | |||||
class TrainBatchLoop(Loop): | class TrainBatchLoop(Loop): | ||||
r""" | |||||
``TrainBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的训练迭代过程; | |||||
:param batch_step_fn: 您可以传入该参数来替换默认的 bath_step_fn; | |||||
""" | |||||
def __init__(self, batch_step_fn: Optional[Callable] = None): | def __init__(self, batch_step_fn: Optional[Callable] = None): | ||||
if batch_step_fn is not None: | if batch_step_fn is not None: | ||||
self.batch_step_fn = batch_step_fn | self.batch_step_fn = batch_step_fn | ||||
def run(self, trainer, dataloader): | def run(self, trainer, dataloader): | ||||
r""" | |||||
对传入的 dataloader 进行一个 epoch 的主要的训练的循环过程; | |||||
.. note:: | |||||
您不需要自己主动地调用该方法,``Trainer`` 会负责调用该方法来完成训练过程; | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param dataloader: 当前训练所使用的 dataloader; | |||||
""" | |||||
get_batch_indices = dataloader.get_batch_indices if callable(getattr(dataloader, 'get_batch_indices', None))\ | get_batch_indices = dataloader.get_batch_indices if callable(getattr(dataloader, 'get_batch_indices', None))\ | ||||
else lambda *args, **kwargs: None | else lambda *args, **kwargs: None | ||||
dataloader = iter(dataloader) | dataloader = iter(dataloader) | ||||
indices = None | |||||
while trainer.batch_idx_in_epoch<=trainer.num_batches_per_epoch: | while trainer.batch_idx_in_epoch<=trainer.num_batches_per_epoch: | ||||
try: | try: | ||||
trainer.on_fetch_data_begin() | trainer.on_fetch_data_begin() | ||||
batch = next(dataloader) | batch = next(dataloader) | ||||
indices = get_batch_indices() | indices = get_batch_indices() | ||||
except StopIteration: | |||||
break | |||||
try: | |||||
trainer.on_fetch_data_end() | trainer.on_fetch_data_end() | ||||
batch = match_and_substitute_params(trainer.input_mapping, batch) | batch = match_and_substitute_params(trainer.input_mapping, batch) | ||||
batch = trainer.move_data_to_device(batch) | batch = trainer.move_data_to_device(batch) | ||||
except StopIteration: | |||||
break | |||||
trainer.on_train_batch_begin(batch, indices) | |||||
with trainer.get_no_sync_context(): # 在多卡的时候可能需要关闭 sync | |||||
self.batch_step_fn(trainer, batch) | |||||
trainer.global_forward_batches += 1 | |||||
trainer.batch_idx_in_epoch += 1 | |||||
trainer.check_batch_step_fn() | |||||
trainer.on_train_batch_end() | |||||
except BaseException as e: | except BaseException as e: | ||||
if indices and not isinstance(e, EarlyStopException): | |||||
if indices is not None and not isinstance(e, (EarlyStopException, KeyboardInterrupt)): | |||||
logger.error(f"Exception happens when running on samples: {indices}") | logger.error(f"Exception happens when running on samples: {indices}") | ||||
raise e | raise e | ||||
trainer.on_train_batch_begin(batch, indices) | |||||
with trainer.get_no_sync_context(): # 在多卡的时候可能需要关闭 sync | |||||
self.batch_step_fn(trainer, batch) | |||||
trainer.global_forward_batches += 1 | |||||
trainer.batch_idx_in_epoch += 1 | |||||
trainer.check_batch_step_fn() | |||||
trainer.on_train_batch_end() | |||||
trainer.step_evaluate() | trainer.step_evaluate() | ||||
trainer.batch_idx_in_epoch = 0 | trainer.batch_idx_in_epoch = 0 | ||||
@staticmethod | @staticmethod | ||||
def batch_step_fn(trainer, batch): | def batch_step_fn(trainer, batch): | ||||
r""" | |||||
针对一个 batch 的数据的训练过程; | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param batch: 一个 batch 的数据; | |||||
""" | |||||
outputs = trainer.train_step(batch) | outputs = trainer.train_step(batch) | ||||
trainer.backward(outputs) | trainer.backward(outputs) | ||||
trainer.step() | trainer.step() | ||||
@@ -1,19 +1,6 @@ | |||||
""" | |||||
该 Module 用来实现一个用于记载用户 callback 实时数据的 state,该 state 实际上是一个 字典,我们通过复用 __getattr__ 方法来实现类似 | |||||
类属性的字典调用方式; | |||||
提供该类的主要目的在于与 Filter 中的特殊的 filter_fn 合作,方便用户能够使用到自己想要的一切特殊的定制方式; | |||||
这一特殊的 Filter 实现需要用户记录一些特殊的状态值,例如 accuracy 等,但是我们不希望用户将这些状态值直接挂在 trainer 实例上,因为这样会 | |||||
污染 trainer 自己的类属性,从而可能导致一些莫名其妙的 bug; | |||||
我们开放 state 用于用户这一特殊的定制选择; | |||||
""" | |||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
from typing import Optional, Dict | from typing import Optional, Dict | ||||
__all__ = [ | __all__ = [ | ||||
'State', | 'State', | ||||
'TrainerState' | 'TrainerState' | ||||
@@ -22,7 +9,8 @@ __all__ = [ | |||||
class State(dict): | class State(dict): | ||||
r""" | r""" | ||||
提供给用户使用的 state; | |||||
提供给用户使用的 ``state``,用来记载您的 ``callback`` 实时数据,该 ``state`` 实际上是一个字典,我们通过复用 ``__getattr__`` 方法来实现类似 | |||||
类属性的字典调用方式; | |||||
为了实现断点重训,用户应当保证其保存的信息都是可序列化的; | 为了实现断点重训,用户应当保证其保存的信息都是可序列化的; | ||||
@@ -1,4 +1,3 @@ | |||||
import inspect | |||||
from typing import Dict | from typing import Dict | ||||
from fastNLP.core.callbacks import CallbackManager | from fastNLP.core.callbacks import CallbackManager | ||||
@@ -7,10 +6,10 @@ from fastNLP.core.utils.utils import _check_valid_parameters_number | |||||
class TrainerEventTrigger: | class TrainerEventTrigger: | ||||
""" | |||||
r""" | |||||
为了避免在训练流程中调用 callback 函数中写成类似 'trainer.callback_manager.on_train_begin' 的形式,我们选择单独抽象为 'Trainer' | 为了避免在训练流程中调用 callback 函数中写成类似 'trainer.callback_manager.on_train_begin' 的形式,我们选择单独抽象为 'Trainer' | ||||
抽象一层,然后一些特殊的操作可以在这里进行,例如我们通过 `on_validate_end` 来通知所有的 'CheckpointCallback' 实例在当前的 step 后保存 | |||||
模型。 | |||||
抽象一层,然后一些特殊的操作可以在这里进行,例如我们通过 `on_validate_end` 来通知所有的 'CheckpointCallback' 实例在当前的 step 后保存 | |||||
模型。 | |||||
""" | """ | ||||
callback_manager: CallbackManager | callback_manager: CallbackManager | ||||
trainer_state: TrainerState | trainer_state: TrainerState | ||||
@@ -90,13 +89,21 @@ class TrainerEventTrigger: | |||||
class _TruncatedDataLoader: | class _TruncatedDataLoader: | ||||
r""" | |||||
``_TruncatedDataLoader`` 用于实现 ``Trainer`` 和 ``Evaluator`` 中的 '预跑' 和 '假跑' 功能: | |||||
1. 预跑 是针对 trainer 的验证而言的,即我们在正式的训练前会先使用 trainer 内置的 evaluator(如果不为 None)评测数量非常少的数据, | |||||
来检验用户的 metric 和 evaluate_dataloader 以及模型是否能够合作完成正确的评测过程; | |||||
2. 假跑 的意思是改变每一个 epoch 中训练或者评测的实际的 batch 的数量,例如改成 10,来让模型快速地迭代整体的训练或者评测流程,来查看 | |||||
整体的过程的正确性; | |||||
``_TruncatedDataLoader`` 的实现也非常简单,我们在该类中内置一个计数器,当迭代器的迭代数量达到指定数值后 ``raise StopIteration``; | |||||
:param dataloader: 可迭代的 dataloader 。 | |||||
:param num_batches: 迭代多少个 batch 就停止。 | |||||
""" | |||||
def __init__(self, dataloader, num_batches: int): | def __init__(self, dataloader, num_batches: int): | ||||
""" | |||||
限制 | |||||
:param dataloader: 可迭代的 dataloader 。 | |||||
:param num_batches: 迭代多少个 batch 就停止。 | |||||
""" | |||||
self.dataloader = dataloader | self.dataloader = dataloader | ||||
self._num_batches = min(num_batches, len(dataloader)) | self._num_batches = min(num_batches, len(dataloader)) | ||||
self._count = 0 | self._count = 0 | ||||
@@ -104,7 +111,6 @@ class _TruncatedDataLoader: | |||||
def __len__(self): | def __len__(self): | ||||
r""" | r""" | ||||
为了在外部调用 `len` 方法时正确地返回当前会迭代的长度; | 为了在外部调用 `len` 方法时正确地返回当前会迭代的长度; | ||||
""" | """ | ||||
return self._num_batches | return self._num_batches | ||||
@@ -127,6 +133,13 @@ class _TruncatedDataLoader: | |||||
def check_evaluate_every(evaluate_every): | def check_evaluate_every(evaluate_every): | ||||
r""" | |||||
检验用户传入的 ``evaluate_every`` 参数是否合法; | |||||
``evaluate_every`` 的使用详见 ``Trainer`` 的 ``evaluate_every`` 参数; | |||||
主要在于当参数 ``evaluate_every`` 是一个 callable 的函数时,需要保证其参数的正确性; | |||||
""" | |||||
if not callable(evaluate_every) and (not isinstance(evaluate_every, int) or evaluate_every == 0): | if not callable(evaluate_every) and (not isinstance(evaluate_every, int) or evaluate_every == 0): | ||||
raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") | raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") | ||||
if callable(evaluate_every): | if callable(evaluate_every): | ||||
@@ -41,7 +41,7 @@ class JittorDataLoader: | |||||
""" | """ | ||||
def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False, | |||||
def __init__(self, dataset, batch_size: int = 16, shuffle: bool = True, | |||||
drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, | drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, | ||||
stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, | stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, | ||||
collate_fn: Union[None, str, Callable] = "auto") -> None: | collate_fn: Union[None, str, Callable] = "auto") -> None: | ||||
@@ -79,7 +79,7 @@ class PaddleDataLoader(DataLoader): | |||||
def __init__(self, dataset, feed_list=None, places=None, | def __init__(self, dataset, feed_list=None, places=None, | ||||
return_list: bool = True, batch_sampler=None, | return_list: bool = True, batch_sampler=None, | ||||
batch_size: int = 1, shuffle: bool = False, | |||||
batch_size: int = 1, shuffle: bool = True, | |||||
drop_last: bool = False, collate_fn: Union[str, Callable, None] = 'auto', | drop_last: bool = False, collate_fn: Union[str, Callable, None] = 'auto', | ||||
num_workers: int = 0, use_buffer_reader: bool = True, | num_workers: int = 0, use_buffer_reader: bool = True, | ||||
use_shared_memory: bool = True, timeout: int = 0, | use_shared_memory: bool = True, timeout: int = 0, | ||||
@@ -14,7 +14,7 @@ from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS | |||||
from ..log import logger | from ..log import logger | ||||
def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False, | |||||
def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = True, drop_last: bool = False, | |||||
collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0, | collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0, | ||||
seed: int = 0, backend: str = 'auto'): | seed: int = 0, backend: str = 'auto'): | ||||
""" | """ | ||||
@@ -177,6 +177,7 @@ class TorchDataLoader(DataLoader): | |||||
return self.cur_batch_indices | return self.cur_batch_indices | ||||
def prepare_torch_dataloader(ds_or_db, | def prepare_torch_dataloader(ds_or_db, | ||||
train_batch_size: int = 16, | train_batch_size: int = 16, | ||||
shuffle: bool = False, | shuffle: bool = False, | ||||
@@ -236,8 +237,8 @@ def prepare_torch_dataloader(ds_or_db, | |||||
persistent_workers=persistent_workers, | persistent_workers=persistent_workers, | ||||
) | ) | ||||
else: | else: | ||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, | |||||
shuffle=shuffle, sampler=non_train_sampler, | |||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, | |||||
shuffle=shuffle, sampler=non_train_sampler if non_train_sampler else sampler, | |||||
batch_sampler=batch_sampler, | batch_sampler=batch_sampler, | ||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | ||||
@@ -251,7 +252,8 @@ def prepare_torch_dataloader(ds_or_db, | |||||
dl_bundle = [] | dl_bundle = [] | ||||
for idx, ds in enumerate(ds_or_db): | for idx, ds in enumerate(ds_or_db): | ||||
if idx > 0: | if idx > 0: | ||||
train_batch_size = non_train_batch_size | |||||
train_batch_size = non_train_batch_size if non_train_batch_size else train_batch_size | |||||
sampler = non_train_sampler if non_train_sampler else sampler | |||||
dl_bundle.append( | dl_bundle.append( | ||||
TorchDataLoader(dataset=ds, batch_size=train_batch_size, | TorchDataLoader(dataset=ds, batch_size=train_batch_size, | ||||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | ||||
@@ -276,8 +278,8 @@ def prepare_torch_dataloader(ds_or_db, | |||||
persistent_workers=persistent_workers, | persistent_workers=persistent_workers, | ||||
) | ) | ||||
else: | else: | ||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, | |||||
shuffle=shuffle, sampler=non_train_sampler, | |||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, | |||||
shuffle=shuffle, sampler=non_train_sampler if non_train_sampler else sampler, | |||||
batch_sampler=batch_sampler, | batch_sampler=batch_sampler, | ||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | ||||
@@ -1,9 +1,10 @@ | |||||
from typing import Callable | |||||
__all__ = [ | __all__ = [ | ||||
"indice_collate_wrapper" | "indice_collate_wrapper" | ||||
] | ] | ||||
def indice_collate_wrapper(func): | |||||
def indice_collate_wrapper(func:Callable): | |||||
""" | """ | ||||
其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。 | 其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。 | ||||
@@ -1,6 +1,6 @@ | |||||
r""" | r""" | ||||
instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可以认为是一个Instance类型的对象。 | instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可以认为是一个Instance类型的对象。 | ||||
便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset` 中的表格 | |||||
便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset` 。 | |||||
""" | """ | ||||
@@ -14,10 +14,10 @@ from fastNLP.core.utils.utils import pretty_table_printer | |||||
class Instance(Mapping): | class Instance(Mapping): | ||||
r""" | r""" | ||||
Instance是fastNLP中对应一个sample的类。每个sample在fastNLP中是一个Instance对象。 | |||||
Instance一般与 :class:`~fastNLP.DataSet` 一起使用, Instance的初始化如下面的Example所示:: | |||||
Instance 是 fastNLP 中对应一个 sample 的类。每个 sample 在 fastNLP 中是一个 Instance 对象。 | |||||
Instance 一般与 :class:`~fastNLP.DataSet` 一起使用, Instance 的初始化如下面的 Example 所示:: | |||||
instance = Instance() # 请补充完整 | |||||
>>> instance = Instance(input="this is a demo sentence", label='good') # 请补充完整 | |||||
""" | """ | ||||
@@ -44,17 +44,17 @@ class Instance(Mapping): | |||||
def keys(self): | def keys(self): | ||||
r""" | r""" | ||||
返回一个迭代器,内容是field_name | |||||
返回一个迭代器,内容是 field_name | |||||
:return: 一个迭代器 | |||||
:return: 一个迭代器 | |||||
""" | """ | ||||
return self.fields.keys() | return self.fields.keys() | ||||
def values(self): | def values(self): | ||||
r""" | r""" | ||||
返回一个迭代器,内容是field_value | |||||
返回一个迭代器,内容是 field_value | |||||
:return: 一个迭代器 | |||||
:return: 一个迭代器 | |||||
""" | """ | ||||
return self.fields.values() | return self.fields.values() | ||||
@@ -1,7 +1,7 @@ | |||||
import os | import os | ||||
import signal | import signal | ||||
import sys | import sys | ||||
from typing import Any, Sequence, List, Optional, Callable, Dict, Union, Tuple | |||||
from typing import Sequence, List, Optional, Callable, Dict, Union, Tuple | |||||
from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||
from datetime import datetime | from datetime import datetime | ||||
from pathlib import Path | from pathlib import Path | ||||
@@ -19,13 +19,11 @@ class Driver(ABC): | |||||
r""" | r""" | ||||
用来初始化 `Driver` 的基类,所有定制的 `driver` 都需要继承此类; | 用来初始化 `Driver` 的基类,所有定制的 `driver` 都需要继承此类; | ||||
fastNLP 提供的 driver 实例都会同时被 Trainer 和 Evaluator 调用; | fastNLP 提供的 driver 实例都会同时被 Trainer 和 Evaluator 调用; | ||||
:param model: 训练或者评测的模型,需要注意该模型可能为用户已经使用类似 `torch.nn.DataParallel` 或者 | |||||
`torch.nn.parallel.DistributedDataParallel` 包裹过的模型; | |||||
""" | """ | ||||
def __init__(self, model): | def __init__(self, model): | ||||
r""" | |||||
:param model: 训练或者评测的模型,需要注意该模型可能为用户已经使用类似 `torch.nn.DataParallel` 或者 | |||||
`torch.nn.parallel.DistributedDataParallel` 包裹过的模型; | |||||
""" | |||||
self.model = model | self.model = model | ||||
# 这些属性用于 open_subprocess 和 on_exception 函数协同配合; | # 这些属性用于 open_subprocess 和 on_exception 函数协同配合; | ||||
@@ -36,24 +34,25 @@ class Driver(ABC): | |||||
def setup(self): | def setup(self): | ||||
r""" | r""" | ||||
该函数用来初始化训练环境,例如将模型迁移到对应的设备上等; | 该函数用来初始化训练环境,例如将模型迁移到对应的设备上等; | ||||
多卡的 driver 的该函数要更为复杂一些,例如其可能需要开启多进程之间的通信环境,以及设置一些环境变量和其余所需要的变量值; | |||||
多卡的 ``driver`` 的该函数要更为复杂一些,例如其可能需要开启多进程之间的通信环境,以及设置一些环境变量和其余所需要的变量值; | |||||
""" | """ | ||||
def set_dist_repro_dataloader(self, dataloader, dist=None, reproducible: bool = False): | def set_dist_repro_dataloader(self, dataloader, dist=None, reproducible: bool = False): | ||||
r""" | r""" | ||||
根据输入的 dataloader 得到一个 支持分布式 (distributed) 与 可复现的 (reproducible) 的 dataloader。 | |||||
:param dataloader: 根据 dataloader 设置其对应的分布式版本以及可复现版本 | |||||
:param dist: 应当为一个字符串,其值应当为以下之一:[None, "dist", "unrepeatdist"];为 None 时,表示不需要考虑当前 dataloader | |||||
切换为分布式状态;为 'dist' 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在 | |||||
不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | |||||
数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; | |||||
否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; | |||||
注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||||
根据输入的 ``dataloader`` 得到一个 支持分布式 (``distributed``) 与 可复现的 (``reproducible``) 的 dataloader。 | |||||
:param dataloader: 根据 ``dataloade``r 设置其对应的分布式版本以及可复现版本; | |||||
:param dist: 应当为一个字符串,其值应当为以下之一:``[None, "dist", "unrepeatdist"]``;为 ``None`` 时,表示不需要考虑当前 dataloader | |||||
切换为分布式状态;为 ``dist`` 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在 | |||||
不同 gpu 上出现重复;为 ``unrepeatdist`` 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | |||||
数据,允许不同 gpu 上 batch 的数量不一致。 | |||||
其中 trainer 中 kwargs 的参数 ``use_dist_sampler`` 为 ``True`` 时,该值为 ``dist``; | |||||
否则为 ``None``,evaluator 中的 kwargs 的参数 ``use_dist_sampler`` 为 ``True`` 时,该值为 ``unrepeatdist``,否则为 ``None``; | |||||
注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load_checkpoint 函数在调用; | |||||
当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; | 当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; | ||||
:param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 | |||||
可以可以加载。 | |||||
:param reproducible: 如果为 ``False``,不要做任何考虑;如果为 ``True``,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 | |||||
该状态可以加载到一个全新的 dataloader 中然后恢复其状态; | |||||
:return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, | :return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, | ||||
如果传入的 dataloader 中是 ReproducibleSampler 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的 | 如果传入的 dataloader 中是 ReproducibleSampler 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的 | ||||
dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。 | dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。 | ||||
@@ -65,50 +64,50 @@ class Driver(ABC): | |||||
def set_deterministic_dataloader(self, dataloader): | def set_deterministic_dataloader(self, dataloader): | ||||
r""" | r""" | ||||
为了确定性训练要对 dataloader 进行修改,保证在确定随机数种子后,每次重新训练得到的结果是一样的;例如对于 torch 的 dataloader,其 | |||||
需要将 worker_init_fn 替换; | |||||
为了确定性训练要对 ``dataloader`` 进行修改,保证在确定随机数种子后,每次重新训练得到的结果是一样的;例如对于 ``pytorch`` 的 ``dataloader``,其 | |||||
需要将 ``worker_init_fn`` 替换; | |||||
""" | """ | ||||
def set_sampler_epoch(self, dataloader, cur_epoch_idx): | def set_sampler_epoch(self, dataloader, cur_epoch_idx): | ||||
r""" | r""" | ||||
对于分布式的 sampler,例如 torch 的 DistributedSampler,其需要在每一个 epoch 前设置随机数种子,来保证每一个进程上的 shuffle 是一样的; | |||||
dataloader 中可能真正发挥作用的是 batch_sampler 也可能是 sampler。 | |||||
对于分布式的 ``sampler``,例如 ``pytorch`` 的 ``DistributedSampler``,其需要在每一个 ``epoch`` 前设置随机数种子,来保证每一个进程上的 ``shuffle`` 是一样的; | |||||
``dataloader`` 中可能真正发挥作用的是 ``batch_sampler`` 也可能是 ``sampler``。 | |||||
:param dataloader: 需要设置 epoch 的 dataloader 。 | |||||
:param cur_epoch_idx: 当前是第几个 epoch; | |||||
:param dataloader: 需要设置 ``epoch`` 的 ``dataloader``; | |||||
:param cur_epoch_idx: 当前是第几个 ``epoch``; | |||||
""" | """ | ||||
@abstractmethod | @abstractmethod | ||||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | ||||
""" | |||||
通过调用 `fn` 来实现训练时的前向传播过程; | |||||
注意 Trainer 和 Evaluator 会调用该函数来实现网络的前向传播过程,其中传入该函数的参数 `fn` 是函数 `get_model_call_fn` 所返回的 | |||||
r""" | |||||
通过调用 ``fn`` 来实现训练时的前向传播过程; | |||||
注意 ``Trainer`` 和 ``Evaluator`` 会调用该函数来实现网络的前向传播过程,其中传入该函数的参数 ``fn`` 是函数 ``get_model_call_fn`` 所返回的 | |||||
函数; | 函数; | ||||
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | :param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | ||||
:param fn: 调用该函数进行一次计算。 | :param fn: 调用该函数进行一次计算。 | ||||
:param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call 函 | |||||
数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; | |||||
:return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | |||||
:param signature_fn: 由 ``Trainer`` 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 ``Dict`` 的时候,我们会自动调用 ``auto_param_call`` 函 | |||||
数,而一些被包裹的模型需要暴露其真正的函数签名,例如 ``DistributedDataParallel`` 的调用函数是 ``forward``,但是需要其函数签名为 ``model.module.forward``; | |||||
:return: 返回由 ``fn`` 返回的结果(应当为一个 ``dict`` 或者 ``dataclass``,但是不需要我们去检查); | |||||
""" | """ | ||||
raise NotImplementedError("Each specific driver should implemented its own `model_call` function.") | raise NotImplementedError("Each specific driver should implemented its own `model_call` function.") | ||||
@abstractmethod | @abstractmethod | ||||
def get_model_call_fn(self, fn: str) -> Tuple: | def get_model_call_fn(self, fn: str) -> Tuple: | ||||
""" | |||||
该函数会接受 Trainer 的 train_fn 或者 Evaluator 的 evaluate_fn,返回一个实际用于调用 driver.model_call 时传入的函数参数; | |||||
该函数会在 Trainer 和 Evaluator 在 driver.setup 函数之后调用; | |||||
r""" | |||||
该函数会接受 ``Trainer`` 的 ``train_fn`` 或者 ``Evaluator`` 的 ``evaluate_fn``,返回一个实际用于调用 ``driver.model_call`` 时传入的函数参数; | |||||
该函数会在 ``Trainer`` 和 ``Evaluator`` 在 ``driver.setup`` 函数之后调用; | |||||
之所以设置该函数的目的在于希望将具体的 model_call function 从 driver 中抽离出来,然后将其附着在 Trainer 或者 Evaluator 身上; | 之所以设置该函数的目的在于希望将具体的 model_call function 从 driver 中抽离出来,然后将其附着在 Trainer 或者 Evaluator 身上; | ||||
这样是因为在新版的设计中,使用 model 的哪种方法来进行 `train step` 或者 `evaluate step` 是通过额外的参数 `train_fn` 和 | |||||
`evaluate_fn` 来确定的,而二者又分别是通过 Trainer 和 Evaluator 来控制的;因此不能将确定具体的 `train step fn` 和 | |||||
`evaluate step fn` 的逻辑放在每一个 driver 的初始化的时候(因此在 Trainer 初始化第一个 driver 时,Evaluator 还没有初始化,但是 | |||||
`evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中; | |||||
这样是因为在新版的设计中,使用 model 的哪种方法来进行 ``train step`` 或者 ``evaluate step`` 是通过额外的参数 ``train_fn`` 和 | |||||
``evaluate_fn`` 来确定的,而二者又分别是通过 Trainer 和 Evaluator 来控制的;因此不能将确定具体的 ``train step fn`` 和 | |||||
``evaluate step fn`` 的逻辑放在每一个 driver 的初始化的时候(因此在 Trainer 初始化第一个 driver 时,Evaluator 还没有初始化,但是 | |||||
``evaluate step fn`` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中; | |||||
这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数,具体逻辑如下所示: | |||||
1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward` | |||||
这一函数应当通过参数 ``fn`` 来判断应当返回的实际的调用的函数,具体逻辑如下所示: | |||||
1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 ``fn``,则默认调用模型的 ``forward`` | |||||
函数,然后给出 warning; | 函数,然后给出 warning; | ||||
2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错; | |||||
2. 如果 fn 是其他字符串,那么如果模型没有定义方法 ``fn`` 则直接报错; | |||||
注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的 | 注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的 | ||||
forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此 | forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此 | ||||
@@ -121,6 +120,9 @@ class Driver(ABC): | |||||
@property | @property | ||||
def model(self): | def model(self): | ||||
r""" | |||||
:return: 返回 driver 中在实际训练或者评测时所使用的模型; | |||||
""" | |||||
return self._model | return self._model | ||||
@model.setter | @model.setter | ||||
@@ -147,6 +149,9 @@ class Driver(ABC): | |||||
@property | @property | ||||
def model_device(self): | def model_device(self): | ||||
r""" | |||||
:return: 返回 driver 中模型实际所在的设备; | |||||
""" | |||||
return self._model_device | return self._model_device | ||||
@model_device.setter | @model_device.setter | ||||
@@ -155,28 +160,30 @@ class Driver(ABC): | |||||
@property | @property | ||||
def data_device(self): | def data_device(self): | ||||
""" | |||||
:return: 返回 driver 中数据默认会被迁移到的设备; | |||||
""" | |||||
return self.model_device | return self.model_device | ||||
@staticmethod | @staticmethod | ||||
def _check_optimizer_legality(optimizers): | def _check_optimizer_legality(optimizers): | ||||
""" | |||||
r""" | |||||
对于用户传入 trainer 的每一个 optimizer,检测其是否合理,因为不同的深度学习框架所使用的的 optimizer 是不相同的; | 对于用户传入 trainer 的每一个 optimizer,检测其是否合理,因为不同的深度学习框架所使用的的 optimizer 是不相同的; | ||||
:param optimizers: 需要检测的 `optimizers`; | :param optimizers: 需要检测的 `optimizers`; | ||||
""" | """ | ||||
raise NotImplementedError("Each specific driver should implemented its own `_check_optimizer_legality` function.") | |||||
raise NotImplementedError( | |||||
"Each specific driver should implemented its own `_check_optimizer_legality` function.") | |||||
def set_optimizers(self, optimizers=None): | def set_optimizers(self, optimizers=None): | ||||
""" | |||||
r""" | |||||
trainer 会调用该函数将用户传入的 optimizers 挂载到 driver 实例上; | trainer 会调用该函数将用户传入的 optimizers 挂载到 driver 实例上; | ||||
:param optimizers: | |||||
:return: | |||||
""" | """ | ||||
self.optimizers = optimizers | self.optimizers = optimizers | ||||
@abstractmethod | @abstractmethod | ||||
def backward(self, loss): | def backward(self, loss): | ||||
""" | |||||
r""" | |||||
实现深度学习中的反向传播过程; | 实现深度学习中的反向传播过程; | ||||
:param loss: 用来实现反向传播的损失函数值; | :param loss: 用来实现反向传播的损失函数值; | ||||
@@ -219,7 +226,7 @@ class Driver(ABC): | |||||
@property | @property | ||||
def auto_cast(self): | def auto_cast(self): | ||||
""" | |||||
r""" | |||||
fp16 的上下文环境; | fp16 的上下文环境; | ||||
:return: 返回一个用于 fp16 计算的上下文环境; | :return: 返回一个用于 fp16 计算的上下文环境; | ||||
@@ -246,7 +253,7 @@ class Driver(ABC): | |||||
r""" | r""" | ||||
加载模型的函数;将 filepath 中的模型加载并赋值给当前 model 。 | 加载模型的函数;将 filepath 中的模型加载并赋值给当前 model 。 | ||||
:param filepath: 需要被加载的对象的文件位置(需要包括文件名)或一个 BytesIO 对象; | |||||
:param filepath: 需要被加载的对象的文件位置(需要包括文件名)或一个 ``BytesIO`` 对象; | |||||
:param load_state_dict: 保存的文件是否只是模型的权重,还是完整的模型。即便是保存的完整的模型,此处也只能使用尝试加载filepath | :param load_state_dict: 保存的文件是否只是模型的权重,还是完整的模型。即便是保存的完整的模型,此处也只能使用尝试加载filepath | ||||
模型中的权重到自身模型,而不会直接替代当前 Driver 中的模型。 | 模型中的权重到自身模型,而不会直接替代当前 Driver 中的模型。 | ||||
:return: 返回加载指定文件后的结果; | :return: 返回加载指定文件后的结果; | ||||
@@ -254,61 +261,65 @@ class Driver(ABC): | |||||
raise NotImplementedError("Each specific driver should implemented its own `load_model` function.") | raise NotImplementedError("Each specific driver should implemented its own `load_model` function.") | ||||
@abstractmethod | @abstractmethod | ||||
def save(self, folder, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||||
def save_checkpoint(self, folder, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, | |||||
**kwargs): | |||||
r""" | r""" | ||||
断点重训的保存函数,该函数会负责保存模型和 optimizers, fp16 的 state_dict;以及模型的保存(若 should_save_model 为 True) | 断点重训的保存函数,该函数会负责保存模型和 optimizers, fp16 的 state_dict;以及模型的保存(若 should_save_model 为 True) | ||||
:param folder: 保存断点重训的状态的文件夹;save 函数应该在下面新增两(一)个文件 的 FASTNLP_CHECKPOINT_FILENAME 文件与 | |||||
:param folder: 保存断点重训的状态的文件夹;save_checkpoint 函数应该在下面新增两(一)个文件 的 FASTNLP_CHECKPOINT_FILENAME 文件与 | |||||
FASTNLP_MODEL_FILENAME (如果 should_save_model 为 True )。把 model 相关的内容放入到 FASTNLP_MODEL_FILENAME 文件 | FASTNLP_MODEL_FILENAME (如果 should_save_model 为 True )。把 model 相关的内容放入到 FASTNLP_MODEL_FILENAME 文件 | ||||
中,将传入的 states 以及自身产生其它状态一并保存在 FASTNLP_CHECKPOINT_FILENAME 里面。 | 中,将传入的 states 以及自身产生其它状态一并保存在 FASTNLP_CHECKPOINT_FILENAME 里面。 | ||||
:param states: 由 trainer 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态,Driver 应该只需要保存 | :param states: 由 trainer 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态,Driver 应该只需要保存 | ||||
该对象即可, Driver 应该不需要理解该对象,同时在 driver.load() 的时候,需要将 states 返回回去,load() 返回的值与这里的 | |||||
该对象即可, Driver 应该不需要理解该对象,同时在 driver.load_checkpoint() 的时候,需要将 states 返回回去,load_checkpoint() 返回的值与这里的 | |||||
传入的值保持一致。 | 传入的值保持一致。 | ||||
:param dataloader: 正在使用的 dataloader,需要保存里面的状态使得之后可以从当前迭代的位置恢复。 | :param dataloader: 正在使用的 dataloader,需要保存里面的状态使得之后可以从当前迭代的位置恢复。 | ||||
:param only_state_dict: 是否只保存模型的参数,当 should_save_model 为 False ,该参数无效。 | :param only_state_dict: 是否只保存模型的参数,当 should_save_model 为 False ,该参数无效。 | ||||
:param should_save_model: 是否应该保存模型,如果为False,Driver 将不负责 model 的保存。 | :param should_save_model: 是否应该保存模型,如果为False,Driver 将不负责 model 的保存。 | ||||
""" | """ | ||||
raise NotImplementedError("Each specific driver should implemented its own `save` function.") | |||||
raise NotImplementedError("Each specific driver should implemented its own `save_checkpoint` function.") | |||||
@abstractmethod | @abstractmethod | ||||
def load(self, folder: Union[str, Path], dataloader, only_state_dict: bool =True, should_load_model: bool = True, **kwargs) -> Dict: | |||||
def load_checkpoint(self, folder: Union[str, Path], dataloader, only_state_dict: bool = True, should_load_model: bool = True, | |||||
**kwargs) -> Dict: | |||||
r""" | r""" | ||||
断点重训的加载函数,注意该函数会负责读取数据,并且恢复 optimizers , fp16 的 state_dict 和 模型(根据 should_load_model )和; | 断点重训的加载函数,注意该函数会负责读取数据,并且恢复 optimizers , fp16 的 state_dict 和 模型(根据 should_load_model )和; | ||||
其它在 Driver.save() 函数中执行的保存操作,然后将一个 state 字典返回给 trainer ( 内容为Driver.save() 接受到的 states )。 | |||||
其它在 Driver.save_checkpoint() 函数中执行的保存操作,然后将一个 state 字典返回给 trainer ( 内容为Driver.save_checkpoint() 接受到的 states )。 | |||||
该函数应该在所有 rank 上执行。 | 该函数应该在所有 rank 上执行。 | ||||
:param folder: 读取该 folder 下的 FASTNLP_CHECKPOINT_FILENAME 文件与 FASTNLP_MODEL_FILENAME | :param folder: 读取该 folder 下的 FASTNLP_CHECKPOINT_FILENAME 文件与 FASTNLP_MODEL_FILENAME | ||||
(如果 should_load_model 为True)。 | (如果 should_load_model 为True)。 | ||||
:param dataloader: 当前给定 dataloader,需要根据 save 的 dataloader 状态合理设置。若该值为 None ,是不需要返回 'dataloader' | |||||
:param dataloader: 当前给定 dataloader,需要根据保存的 dataloader 状态合理设置。若该值为 None ,是不需要返回 'dataloader' | |||||
以及 'batch_idx_in_epoch' 这两个值。 | 以及 'batch_idx_in_epoch' 这两个值。 | ||||
:param only_state_dict: 读取的,当 should_save_model 为 False ,该参数无效。如果为 True ,说明保存的内容为权重;如果为 | :param only_state_dict: 读取的,当 should_save_model 为 False ,该参数无效。如果为 True ,说明保存的内容为权重;如果为 | ||||
False 说明保存的是模型,但也是通过当前 Driver 的模型去加载保存的模型的权重,而不是使用保存的模型替换当前模型。 | False 说明保存的是模型,但也是通过当前 Driver 的模型去加载保存的模型的权重,而不是使用保存的模型替换当前模型。 | ||||
:param should_load_model: 是否应该加载模型,如果为False,Driver 将不负责加载模型。若该参数为 True ,但在保存的状态中没有 | :param should_load_model: 是否应该加载模型,如果为False,Driver 将不负责加载模型。若该参数为 True ,但在保存的状态中没有 | ||||
找到对应的模型状态,则报错。 | 找到对应的模型状态,则报错。 | ||||
:return: 需要返回 save 函数输入的 states 内容 | |||||
'dataloader',返回的是根据传入的 dataloader 与 保存的状态一起设置为合理的状态,可以返回的对象与传入的dataloader是同一个。 | |||||
在保存与当前传入 data sample 数目不一致时报错。 | |||||
'batch_idx_in_epoch': int 类型的数据,表明当前 epoch 进行到了进行到了第几个 batch 了。 请注意,该值不能是只能通过保存的 | |||||
数据中读取的,因为前后两次运行 batch_size 可能由变化。该数字的原则应该符合以下等式 | |||||
'返回 dataloader 还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数' 。 | |||||
由于 '返回 dataloader 还会产生的batch数量' 这个数量在 batch_size 与 drop_last 参数给定的情况下,无法改变,因此 | |||||
只能通过调整 batch_idx_in_epoch 这个值来使等式成立。一个简单的计算原则如下 | |||||
当drop_last为True,等同于 floor(sample_in_this_rank/batch_size) - floor(num_left_samples/batch_size); | |||||
当drop_last为False,等同于 ceil(sample_in_this_rank/batch_size) - ceil(num_left_samples/batch_size)。 | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `load` function.") | |||||
:return: 需要返回 save_checkpoint 函数输入的 states 内容 | |||||
* *dataloader* -- 返回的是根据传入的 dataloader 与 保存的状态一起设置为合理的状态,可以返回的对象与传入的dataloader是同一个。 | |||||
在保存与当前传入 data sample 数目不一致时报错。 | |||||
* *batch_idx_in_epoch* -- int 类型的数据,表明当前 epoch 进行到了进行到了第几个 batch 了。 请注意,该值不能是只能通过保存的 | |||||
数据中读取的,因为前后两次运行 batch_size 可能由变化。该数字的原则应该符合以下等式 | |||||
'返回 dataloader 还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数' 。 | |||||
由于 '返回 dataloader 还会产生的batch数量' 这个数量在 batch_size 与 drop_last 参数给定的情况下,无法改变,因此 | |||||
只能通过调整 batch_idx_in_epoch 这个值来使等式成立。一个简单的计算原则如下 | |||||
当drop_last为True,等同于 floor(sample_in_this_rank/batch_size) - floor(num_left_samples/batch_size); | |||||
当drop_last为False,等同于 ceil(sample_in_this_rank/batch_size) - ceil(num_left_samples/batch_size)。 | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `load_checkpoint` function.") | |||||
@staticmethod | @staticmethod | ||||
def tensor_to_numeric(tensor, reduce: Optional[str]=None): | |||||
def tensor_to_numeric(tensor, reduce: Optional[str] = None): | |||||
r""" | r""" | ||||
将一个 `tensor` 对象(仅处理当前 driver 使用的 tensor 即可)转换为 python 的 `numeric` 对象;如果 tensor 只包含一个 | |||||
元素则返回 float 或 int 。 | |||||
将一个 ``tensor`` 对象(仅处理当前 driver 使用的 tensor 即可)转换为 python 的 ``numeric`` 对象;如果 ``tensor`` 只包含一个 | |||||
元素则返回 ``float`` 或 ``int``; | |||||
:param tensor: 需要被转换的 `tensor` 对象 | |||||
:param reduce: 可选 ['sum', 'max', 'mea', 'min'],如果不为 None 将使用该 reduce 方法来处理当前 tensor 再返回 | |||||
float 或 int 对象。 | |||||
:return: 转换后返回的结果 | |||||
:param tensor: 需要被转换的 `tensor` 对象; | |||||
:param reduce: 可选 ``['sum', 'max', 'mea', 'min']``,如果不为 ``None`` 将使用该 ``reduce`` 方法来处理当前 ``tensor`` 再返回 | |||||
``float`` 或 ``int`` 对象; | |||||
:return: 转换后返回的结果; | |||||
""" | """ | ||||
raise NotImplementedError("Each specific driver should implemented its own `tensor_to_numeric` function.") | raise NotImplementedError("Each specific driver should implemented its own `tensor_to_numeric` function.") | ||||
@@ -321,7 +332,7 @@ class Driver(ABC): | |||||
""" | """ | ||||
def unwrap_model(self): | def unwrap_model(self): | ||||
""" | |||||
r""" | |||||
保证用户拿到的模型一定是最原始的模型; | 保证用户拿到的模型一定是最原始的模型; | ||||
注意因为我们把保存模型的主要逻辑和代码移到了 `Driver` 中,因此在 `save_model` 函数中,一定要先调用此函数来保证我们保存的模型一定是 | 注意因为我们把保存模型的主要逻辑和代码移到了 `Driver` 中,因此在 `save_model` 函数中,一定要先调用此函数来保证我们保存的模型一定是 | ||||
最为原始的模型; | 最为原始的模型; | ||||
@@ -342,14 +353,14 @@ class Driver(ABC): | |||||
@abstractmethod | @abstractmethod | ||||
def move_data_to_device(self, batch): | def move_data_to_device(self, batch): | ||||
r""" | r""" | ||||
将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。 | |||||
将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构; | |||||
:return: 将移动到指定机器上的 batch 对象返回; | :return: 将移动到指定机器上的 batch 对象返回; | ||||
""" | """ | ||||
def get_local_rank(self) -> int: | def get_local_rank(self) -> int: | ||||
r""" | r""" | ||||
返回当前的local_rank,本函数的返回值只在运行分布式训练的时候有实际含义。 | |||||
返回当前的local_rank,本函数的返回值只在运行分布式训练的时候有实际含义; | |||||
:return: 一个整数值,表示当前进程在当前这台机器上的序号; | :return: 一个整数值,表示当前进程在当前这台机器上的序号; | ||||
""" | """ | ||||
@@ -358,13 +369,13 @@ class Driver(ABC): | |||||
def barrier(self): | def barrier(self): | ||||
r""" | r""" | ||||
用于在多进程工作时同步各进程的工作进度,运行快的进程运行到这里会等待运行慢的进程,只有所有进程都运行到此函数时,所有的进程才会继续运行; | 用于在多进程工作时同步各进程的工作进度,运行快的进程运行到这里会等待运行慢的进程,只有所有进程都运行到此函数时,所有的进程才会继续运行; | ||||
仅在多分布式训练场景中有使用。 | |||||
仅在多分布式训练场景中有使用; | |||||
注意,该函数的行为会受到 FASTNLP_NO_SYNC 的影响。仅当 FASTNLP_NO_SYNC 在 os.environ 中不存在,或小于 1 时才真的执行 barrier 。 | |||||
注意,该函数的行为会受到 FASTNLP_NO_SYNC 的影响。仅当 FASTNLP_NO_SYNC 在 os.environ 中不存在,或小于 1 时才真的执行 barrier; | |||||
""" | """ | ||||
def is_distributed(self) -> bool: | def is_distributed(self) -> bool: | ||||
""" | |||||
r""" | |||||
当前的 driver 实例是否是分布式的; | 当前的 driver 实例是否是分布式的; | ||||
:return: 返回一个 bool 值,如果当前的 driver 实例是用于分布式的,那么返回 True; | :return: 返回一个 bool 值,如果当前的 driver 实例是用于分布式的,那么返回 True; | ||||
@@ -372,7 +383,7 @@ class Driver(ABC): | |||||
return False | return False | ||||
def on_exception(self): | def on_exception(self): | ||||
""" | |||||
r""" | |||||
该函数用于在训练或者预测过程中出现错误时正确地关掉其它的进程,这一点是通过在多进程 driver 调用 open_subprocess 的时候将每一个进程 | 该函数用于在训练或者预测过程中出现错误时正确地关掉其它的进程,这一点是通过在多进程 driver 调用 open_subprocess 的时候将每一个进程 | ||||
的 pid 记录下来,然后在出现错误后,由出现错误的进程手动地将其它进程 kill 掉; | 的 pid 记录下来,然后在出现错误后,由出现错误的进程手动地将其它进程 kill 掉; | ||||
@@ -390,40 +401,38 @@ class Driver(ABC): | |||||
'exc_local_rank': self.get_local_rank(), | 'exc_local_rank': self.get_local_rank(), | ||||
} | } | ||||
sys.stderr.write("\nException info:\n") | sys.stderr.write("\nException info:\n") | ||||
sys.stderr.write(json.dumps(_write_exc_info, indent=2)+"\n") | |||||
sys.stderr.write(json.dumps(_write_exc_info, indent=2) + "\n") | |||||
sys.stderr.write(f"Start to stop these pids:{self._pids}, please wait several seconds.\n") | sys.stderr.write(f"Start to stop these pids:{self._pids}, please wait several seconds.\n") | ||||
for pid in self._pids: | for pid in self._pids: | ||||
if pid != os.getpid(): | if pid != os.getpid(): | ||||
os.kill(pid, signal.SIGKILL) | os.kill(pid, signal.SIGKILL) | ||||
def broadcast_object(self, obj, src:int=0, group=None, **kwargs): | |||||
""" | |||||
从 src 端将 obj 对象(可能是 tensor ,可能是 object )broadcast 到其它所有进程。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行 | |||||
传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。 | |||||
def broadcast_object(self, obj, src: int = 0, group=None, **kwargs): | |||||
r""" | |||||
从 ``src`` 端将 ``obj`` 对象(可能是 ``tensor``,可能是 ``object`` )broadcast 到其它所有进程。如果是非 ``tensor`` 的对象会尝试使用 ``pickle`` 进行打包进行 | |||||
传输,然后再 ``dst`` 处再加载回来。仅在分布式的 ``driver`` 中有实际意义。 | |||||
:param obj: obj,可能是 Tensor 或 嵌套类型的数据 | |||||
:param int src: source 的 global rank 。 | |||||
:param group: 所属的 group | |||||
:param kwargs: | |||||
:return: 输入的 obj 。 | |||||
:param obj: obj,可能是 ``Tensor`` 或 嵌套类型的数据; | |||||
:param src: source 的 ``global rank``; | |||||
:param group: 所属的通信组; | |||||
:return: 输入的 ``obj``; | |||||
""" | """ | ||||
if not self.is_distributed(): | if not self.is_distributed(): | ||||
return obj | return obj | ||||
raise NotImplementedError(f"Driver:{self.__class__.__name__} does not support `broadcast_object` method right " | raise NotImplementedError(f"Driver:{self.__class__.__name__} does not support `broadcast_object` method right " | ||||
f"now.") | f"now.") | ||||
def all_gather(self, obj, group)->List: | |||||
""" | |||||
def all_gather(self, obj, group) -> List: | |||||
r""" | |||||
将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过 | 将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过 | ||||
pickle 进行序列化,接收到之后再反序列化。 | pickle 进行序列化,接收到之后再反序列化。 | ||||
:param obj: 可以是 float/int/bool/np.ndarray/{}/[]/Tensor等。 | |||||
:param group: | |||||
:return: 返回值应该是 [obj0, obj1, ...], 其中obj1是rank0上的对象,obj1是rank1上的对象... | |||||
:param obj: 可以是 ``float/int/bool/np.ndarray/{}/[]/Tensor`` 等; | |||||
:param group: 用于不同进程之间互相通信的通信组; | |||||
:return: 返回值应该是 ``[obj0, obj1, ...]``,其中 ``obj1`` 是 ``rank0`` 上的对象,``obj1`` 是 ``rank1`` 上的对象; | |||||
""" | """ | ||||
if not self.is_distributed(): | if not self.is_distributed(): | ||||
return [obj] | return [obj] | ||||
raise NotImplementedError(f"Driver:{self.__class__.__name__} does not support `all_gather` method right " | raise NotImplementedError(f"Driver:{self.__class__.__name__} does not support `all_gather` method right " | ||||
f"now.") | f"now.") | ||||
@@ -21,6 +21,9 @@ if _NEED_IMPORT_JITTOR: | |||||
'sum': jt.sum | 'sum': jt.sum | ||||
} | } | ||||
__all__ = [ | |||||
"JittorDriver", | |||||
] | |||||
class JittorDriver(Driver): | class JittorDriver(Driver): | ||||
r""" | r""" | ||||
@@ -90,9 +93,6 @@ class JittorDriver(Driver): | |||||
"'test_step'.") | "'test_step'.") | ||||
def save_model(self, filepath: str, only_state_dict: bool = False, model_save_fn: Optional[Callable]=None): | def save_model(self, filepath: str, only_state_dict: bool = False, model_save_fn: Optional[Callable]=None): | ||||
""" | |||||
保存模型 | |||||
""" | |||||
if model_save_fn is not None: | if model_save_fn is not None: | ||||
outputs = model_save_fn(filepath) | outputs = model_save_fn(filepath) | ||||
if outputs is not None: | if outputs is not None: | ||||
@@ -105,20 +105,14 @@ class JittorDriver(Driver): | |||||
jt.save(states, filepath) | jt.save(states, filepath) | ||||
def load_model(self, filepath: str): | def load_model(self, filepath: str): | ||||
""" | |||||
加载模型的加载函数; | |||||
:param file_path: 保存文件的文件位置(需要包括文件名); | |||||
:return: 加载后的state_dict | |||||
""" | |||||
if not os.path.exists(filepath): | if not os.path.exists(filepath): | ||||
raise FileNotFoundError("Checkpoint at {} not found.".format(filepath)) | raise FileNotFoundError("Checkpoint at {} not found.".format(filepath)) | ||||
return jt.load(filepath) | return jt.load(filepath) | ||||
def save(self): | |||||
def save_checkpoint(self): | |||||
... | ... | ||||
def load(self): | |||||
def load_checkpoint(self): | |||||
... | ... | ||||
def get_evaluate_context(self): | def get_evaluate_context(self): | ||||
@@ -156,7 +150,7 @@ class JittorDriver(Driver): | |||||
def move_data_to_device(self, batch: 'jt.Var'): | def move_data_to_device(self, batch: 'jt.Var'): | ||||
""" | """ | ||||
jittor暂时没有提供数据迁移的函数,因此这个函数只是简单地返回batch | |||||
**jittor** 暂时没有提供数据迁移的函数,因此这个函数只是简单地返回 **batch** | |||||
""" | """ | ||||
return batch | return batch | ||||
@@ -20,6 +20,10 @@ class JittorMPIDriver(JittorDriver): | |||||
这是一个正在开发中的功能,敬请期待。 | 这是一个正在开发中的功能,敬请期待。 | ||||
.. todo: | |||||
实现断点重训中替换 dataloader 的 set_dist_repro_dataloader 函数 | |||||
""" | """ | ||||
def __init__( | def __init__( | ||||
self, | self, | ||||
@@ -9,7 +9,7 @@ __all__ = [] | |||||
class DummyGradScaler: | class DummyGradScaler: | ||||
""" | """ | ||||
用于仿造的GradScaler对象,防止重复写大量的if判断 | |||||
用于仿造的 **GradScaler** 对象,防止重复写大量的if判断 | |||||
""" | """ | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
pass | pass | ||||
@@ -21,6 +21,8 @@ if _NEED_IMPORT_PADDLE: | |||||
_parse_load_result, | _parse_load_result, | ||||
) | ) | ||||
__all__ = [] | |||||
def _validate_output_list_for_rank(my_rank, dst, gather_list): | def _validate_output_list_for_rank(my_rank, dst, gather_list): | ||||
if dst == my_rank: | if dst == my_rank: | ||||
if not gather_list: | if not gather_list: | ||||
@@ -1,3 +1,69 @@ | |||||
r""" | |||||
用于实现 **PaddlePaddle** 框架下使用 ``fleet`` 分布式训练 API 进行集群式(*collective*)多卡训练的 Driver。 | |||||
.. note:: | |||||
在 **PaddlePaddle** 框架中,使用分布式训练的方式可以参见 **PaddlePaddle** 的 | |||||
`官方文档 <https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/06_distributed_training/cluster_quick_start_cn.html>`_ 。 | |||||
简言之,分布式训练的过程可以概括为:导入 ``fleet`` 包 -> 使用 :func:`fleet.init` 初始化分布式环境 -> 初始化模型,转换为并行模型开始训练。 | |||||
**fastNLP** 支持三种启动分布式训练的方式(假设执行训练的文件名为 ``train.py``): | |||||
A. 用户自己不进行分布式的任何操作,直接使用我们的 :class:`~fastNLP.core.Trainer` 进行训练,此时将参数 ``device`` | |||||
设置为一个列表,然后使用 ``python train.py`` 的方式开始训练; | |||||
B. 用户自己不进行分布式的任何操作,但是使用 ``python -m paddle.distributed.launch train.py`` 开始训练; | |||||
C. 用户自己在外面初始化分布式环境,并且通过 ``python -m paddle.distributed.launch train.py`` 开始训练; | |||||
.. note:: | |||||
在后两种启动方式中,您需要通过参数 ``--gpus`` 来指定训练使用的设备,在 ``trainer`` 中设置的参数是无效的。 | |||||
不过在使用该 Driver 之前,我们需要向您说明 **fastNLP** 实现 ``PaddleFleetDriver`` 的思路,以便于您理解代码编写过程中可能出现的问题。 | |||||
在 **fastNLP** 中,为了尽可能减少单卡向分布式训练转换过程中的代码变动,我们需要在 ``PaddleFleetDriver`` 中进行 **分布式环境初始化** | |||||
和 **将模型转换为并行模式** 等操作,同时实现多卡训练的方法是从主进程(``rank=0``)中创建其它的所有子进程(``rank=1,2,...``)。 | |||||
在这个过程中,我们发现由于 **PaddlePaddle** 框架的特性,会出现下面的问题: | |||||
1. **fastNLP** 中,初始化模型一定会在初始化 ``Driver`` 之前,因此调用 :func:`fleet.init` 的时机会在初始化模型之后; | |||||
此时子进程中模型将无法正常地初始化,提示无法找到设备 ``gpu:0``; | |||||
2. 在训练的过程中,会出现训练一个 ``batch`` 后程序卡住或程序会占用所有可见显卡的情况; | |||||
考虑到这些问题,我们为 **PaddlePaddle** 的分布式训练制定了这样的约束:在导入 **fastNLP** 之前,必须设置环境变量 ``FASTNLP_BACKEND`` | |||||
为 ``paddle``。执行方法有两种:: | |||||
>>> import os | |||||
>>> os.environ["FASTNLP_BACKEND"] = "paddle" # 设置环境变量 | |||||
>>> import fastNLP # 设置之后才可以导入 fastNLP | |||||
或是在执行脚本(假设文件名为 ``train.py`` )时设置:: | |||||
FASTNLP_BACKEND=paddle python train.py | |||||
FASTNLP_BACKEND=paddle python -m paddle.distributed.lauch train.py | |||||
设置 ``FASTNLP_BACKEND=paddle`` 后,**fastNLP** 会在 ``import paddle`` 之前通过 ``CUDA_VISIBLE_DEVICES`` 将设备限制在所有可见设备的第 | |||||
**0** 张卡上,以此绕开通信和同步上的种种限制。我们会将用户希望可见的设备(如用户自己设置了 ``CUDA_VISIBLE_DEVICES`` 的情况)保存在另一个环境变量 | |||||
``USER_CUDA_VISIBLE_DEVICES`` 中来确保 **fastNLP** 能够知道用户的设置。假设用户希望在 ``[0,2,3]`` 三张显卡上进行分布式训练,那么在三个训练进程中, | |||||
``CUDA_VISIBLE_DEVICES`` 就分别为 0、2 和 3 。 | |||||
.. note:: | |||||
我们会事先将设备限制在所有可见设备的第 **0** 张卡上,因此多卡训练的参数 ``device`` 一定要以 **0** 开始,否则会无法正常地启动。 | |||||
如果您希望调整使用的第一张显卡,请使用 ``CUDA_VISIBLE_DEVICES`` 进行限制。 | |||||
.. note:: | |||||
根据 **PaddlePaddle** 的说明,设置 ``CUDA_VISIBLE_DEVICES`` 之后启动分布式训练时,情况A与情况BC设置设备的方式会有所不同。 | |||||
情况A应设置为实际设备相对可见设备的索引,而情况BC应设置为实际的设备号: | |||||
1. 情况A中, ``CUDA_VISIBLE_DEVICES=3,4,5,6`` 且参数 ``device=[0,2,3]`` 代表使用 **3号、5号和6号** 显卡; | |||||
2. 情况BC中,``CUDA_VISIBLE_DEVICES=3,4,5,6`` 且参数 ``--gpu=3,5,6`` 代表使用 **3号、5号和6号** 显卡; | |||||
.. note:: | |||||
多机的启动强制要求用户在每一台机器上使用 ``python -m paddle.distributed.launch`` 启动;因此我们不会在 ``PaddleFleetDriver`` | |||||
中保存任何当前有多少台机器的信息; | |||||
""" | |||||
import os | import os | ||||
from typing import List, Union, Optional, Dict, Tuple, Callable | from typing import List, Union, Optional, Dict, Tuple, Callable | ||||
@@ -53,6 +119,33 @@ __all__ = [ | |||||
] | ] | ||||
class PaddleFleetDriver(PaddleDriver): | class PaddleFleetDriver(PaddleDriver): | ||||
""" | |||||
:param model: 训练使用的模型; | |||||
* 如果不想自己初始化分布式环境,类型应为 :class:`paddle.nn.Layer`; | |||||
* 如果已经在外面初始化了分布式环境,类型应为 :class:`paddle.DataParallel`; | |||||
:param parallel_device: 多卡训练时使用的设备,必须是一个列表。 | |||||
当使用 ``python -m paddle.distributed.launch`` 启动时,该参数无效; | |||||
:param is_pull_by_paddle_run: 标记当前进程是否为通过 ``python -m paddle.distributed.launch`` 启动的。 | |||||
这个参数仅在 :class:`~fastNLP.core.Trainer` 中初始化 driver 时使用 | |||||
:param fp16: 是否开启混合精度训练; | |||||
:kwargs: | |||||
* *paddle_kwargs* -- 用于在指定 ``driver`` 为 'paddle' 时设定具体 driver 实例的一些参数: | |||||
* fleet_kwargs -- 用于在使用 ``PaddleFleetDriver`` 时指定 ``DataParallel`` 和 ``fleet`` 初始化时的参数,包括: | |||||
* is_collective -- 是否使用 paddle 集群式的分布式训练方法,目前仅支持为 ``True`` 的情况; | |||||
* role_maker -- 初始化 ``fleet`` 分布式训练 API 时使用的 ``RoleMaker`` | |||||
* 其它用于初始化 ``DataParallel`` 的参数; | |||||
* wo_auto_param_call (``bool``) -- 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为; | |||||
.. note:: | |||||
关于该参数的详细说明,请参见 :class:`~fastNLP.core.controllers.Trainer` 中的描述;函数 ``auto_param_call`` 详见 :func:`fastNLP.core.utils.auto_param_call`。 | |||||
""" | |||||
def __init__( | def __init__( | ||||
self, | self, | ||||
model, | model, | ||||
@@ -61,143 +154,20 @@ class PaddleFleetDriver(PaddleDriver): | |||||
fp16: bool = False, | fp16: bool = False, | ||||
**kwargs | **kwargs | ||||
): | ): | ||||
r""" | |||||
通过使用 PaddlePaddle 的 Fleet 框架启动多卡进程的 Driver。 | |||||
需要注意的一点是,由于 PaddlePaddle 框架的特性,如果直接使用在 rank0 拉起其它进程的方法的话,如果不加以任何限制,PaddlePaddle会出现 | |||||
第一次前向传播后卡住或占用所有显卡的现象;为了解决这一问题,我们在引入 FastNLP 时,会使用 `CUDA_VISIBLE_DEVICES` 将设备限制在卡0上, | |||||
而用户如果使用了这一环境变量,我们会将其储存在 `USER_CUDA_VISIBLE_DEVICES` 中,并且通过一定的手段实现了转换(详细的设置请参见: | |||||
`fastNLP/envs/set_backend.py`)。在拉起其它进程的时候,我们会如法炮制,将环境限制在对应的设备上。 | |||||
`PaddleFleetDriver` 目前支持的三种启动方式: | |||||
1. 用户自己不进行分布式的任何操作,直接使用我们的 Trainer,这时是由我们自己使用 `FleetLauncher` 拉起多个进程, | |||||
然后 `PaddleFleetDriver` 自己通过调用 `fleet.init` 来初始化 ddp 的通信组;(情况 A) | |||||
2. 用户同样不在 Trainer 之外初始化分布式训练,但是用户自己使用 python -m paddle.distributed.launch 拉起来创建多个进程,这时我们仍旧 | |||||
会通过调用 `fleet.init` 来初始化 ddp 的通信组;(情况 B) | |||||
3. 用户自己在外面初始化分布式,并且通过 python -m paddle.distributed.launch 拉起,这时无论是多个进程的拉起和通信组的建立 | |||||
都由用户自己操作,我们只会在 driver.setup 的时候对 `PaddleFleetDriver` 设置一些必要的属性值;(情况 C) | |||||
注意多机的启动强制要求用户在每一台机器上使用 python -m paddle.distributed.launch 启动;因此我们不会在 `PaddleFleetDriver` 中保存 | |||||
任何当前有多少台机器的信息; | |||||
Part 1:三种启动方式的具体分析: | |||||
(1)对于用户运行的脚本中,如果 `driver.setup` 只会被调用一次(意味着用户的启动脚本中只初始化了一个 trainer/evaluator)时, | |||||
`PaddleFleetDriver` 在初始化以及 `setup` 函数中会做的事情分别如下所示: | |||||
-> 情况 A:这种情况下用户传入的 model 在一定是普通的 model(没有经 `DataParallel` 包裹的model), | |||||
因为 `Parallel` 的使用一定要求 fleet.init 已经被调用用来建立当前的 ddp 通信组;但是这意味着如果 | |||||
用户需要使用 2 张以上的显卡,那么其必然需要使用 paddle.distributed.launch 来启动,意味着就不是情况 A 了; | |||||
这时我们首先会调用 `FleetLauncher.launch` 函数来拉起多个进程,其中进程的数量等于用户传入给 trainer 的使用的 gpu | |||||
的数量(例如 `Trainer` 中的参数是 device=[0, 1, 6, 7],那么我们就会使用第 0、1、6、7 张 gpu 来拉起 4 个进程); | |||||
接着我们会调用 `fleet.init` 来初始化各个进程之间的通信组; | |||||
这里需要注意拉起的新的进程会从前到后完整地运行一遍用户的启动脚本(例如 main.py),因此也都会运行这两个函数,但是需要注意只有进程 0 | |||||
才会去真正地运行 `FleetLauncher.launch`;进程 0 运行到 `fleet.init`,paddle 会阻塞进程 0 继续 | |||||
向前运行,直到其它进程也运行到这里; | |||||
最后我们会设置这个进程对应的 device,然后将模型迁移到对应的机器上,再使用 `DataParallel` 将模型包裹; | |||||
至此,paddle 分布式的环境配置过程全部完成; | |||||
-> 情况 B:注意这种情况我们直接限定了用户是通过 paddle.distributed.launch 拉起,并且没有自己建立分布式的通信组。这时在 | |||||
`PaddleFleetDriver` 的初始化和 setup 函数的调用过程中,与情况 A 首要的不同就在于用户在 trainer 中输入的参数 device 不再有效, | |||||
这时每个进程所使用的 gpu 是我们直接通过 `CUDA_VISIBLE_DEVICE` 来配置的;因此,如果用户想要实现使用特定 gpu | |||||
设备的目的,可以通过自己设置环境变量实现(例如 os.environ["CUDA_VISIBLE_DEVICE"] 来实现,我们会通过一定的手段将其保存起来); | |||||
剩下的操作和情况 A 类似; | |||||
-> 情况 C:注意这种情况我们限定了用户是通过 paddle.distributed.launch 拉起,并且 ddp 的通信组也是由自己建立。这时基本上所有的 | |||||
与操作相关的操作都应当由用户自己完成,包括迁移模型到对应 gpu 上以及将模型用 `DataParallel` 包裹等。 | |||||
(2)如果 `driver.setup` 函数在脚本中会被调用两次及以上(意味着用户的启动脚本初始化了两个及以上的 trainer/evaluator)时: | |||||
注意这种情况下我们是会保证前后两个 trainer/evaluator 使用的 `PaddleFleetDriver` 以及其初始化方式的一致性,换句话说,如果 trainer1 | |||||
检测到的启动方式是 '情况 A',那么我们会保证 trainer2 检测到的启动方式同样是 '情况A'(即使这需要一些额外的处理);因此这里我们主要讨论 | |||||
我们是通过怎样的操作来保证 trainer2/3/... 检测到的启动方式是和 trainer1 一致的;简单来说,我们是通过使用环境变量来标记每一种不同的 | |||||
启动方式来实现这一点的: | |||||
我们会使用 `FASTNLP_DISTRIBUTED_CHECK` 来标记 '情况 A',使用 `fastnlp_torch_launch_not_ddp` 来标记 '情况 B',意味着我们在 | |||||
使用 '情况 A' 来启动 `PaddleFleetDriver` 时,我们会将 `FASTNLP_DISTRIBUTED_CHECK` 这一字符串注入到环境变量中,而 '情况 B' 时则 | |||||
会将 `fastnlp_torch_launch_not_ddp` 这一字符串注入到环境变量中。因此在 trainer2 的 `PaddleFleetDriver` 的初始化和 setup 过程中, | |||||
如果检测到这些特殊的环境变量,我们就会将启动方式变更为其对应的启动方式,即使其它的参数特征属于另外的启动方式。 | |||||
Part 2:对应的代码细节: | |||||
1. 如何判断当前的各进程之间的通信组已经被建立(fleet 已经被初始化); | |||||
parallel_helper._is_parallel_ctx_initialized(); | |||||
2. 如何判断不同的进程是否是由 `python -m paddle.distributed.launch` 拉起还是由我们的 `FleetLauncher.launch()` | |||||
函数拉起; | |||||
我们会在用户脚本 `import fastNLP` 的时候检测当前的环境变量中是否有 'PADDLE_RANK_IN_NODE'、'PADDLE_TRAINER_ID' | |||||
以及没有 `FASTNLP_DISTRIBUTED_CHECK`, | |||||
如果满足条件,则我们会向环境变量中注入特殊的值 'FASTNLP_BACKEND_LAUNCH' 来标记用户是否使用了 `python -m paddle.distributed.launch` | |||||
来拉起多个进程; | |||||
3. 整体的处理判断流程: | |||||
___________________________________ | |||||
|进入 PaddleFleetDriver 的 __init__ 函数| | |||||
——————————————————————————————————— | |||||
↓ | |||||
___________________________________________________ | |||||
| 判断不同的进程是否是由 paddle.distributed.launch 拉起 | | |||||
|(或者我们自己的 FleetLauncher 函数拉起) | --------------> | |||||
——————————————————————————————————————————————————— | | |||||
↓ 是由 paddle.distributed.launch 拉起 | 我们自己的 FleetLauncher 函数拉起多个进程 | |||||
_____________________________ | | |||||
←←←←← | 检测用户是否自己初始化了 fleet | | | |||||
↓ ————————————————————————————— ↓ | |||||
↓ ↓ 是 ________ | |||||
↓ ______ | 情况 A | | |||||
↓ 否 |情况 C| ————————— | |||||
↓ ——————— | |||||
↓ | |||||
↓ ______ | |||||
↓ -----------> |情况 B| | |||||
——————— | |||||
4. 为了完成全部的建立分布式所需要的操作,三种情况都需要做的事情,以及每件事情的职责归属: | |||||
情况 A | 情况 B | 情况 C | |||||
________________________________________________________________________________________________________ | |||||
配置 fleet 所 | FleetLauncher.launch | paddle.distributed.launch| paddle.distributed.launch | |||||
需要的环境变量 | | | | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
开启多个进程 | FleetLauncher.launch | paddle.distributed.launch| paddle.distributed.launch | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
调用 fleet.init函数 | PaddleFleetDriver.setup | PaddleFleetDriver.setup | 用户自己调用 | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
设置 PaddleFleetDriver | | | | |||||
的 world_size 和 | PaddleFleetDriver.setup | PaddleFleetDriver.setup | PaddleFleetDriver.setup | |||||
global_rank 属性 | | | | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
Part 3:其它的处理细节: | |||||
1. 环境变量; | |||||
fastNLP 的 `PaddleFleetDriver` 运行时所需要的环境变量分为两种,一种是 paddle fleet 运行所需要的环境变量;另一种是 fastNLP 自己 | |||||
的环境变量。前者的配置情况如上表所示;而后者中的大多数环境变量则是在用户 import fastNLP 时就设置好了; | |||||
2. parallel_device, model_device 和 data_device 的关系; | |||||
parallel_device 为 `PaddleFleetDriver` 的参数,model_device 和 data_device 都为 driver 的属性; | |||||
其中 data_device 仅当情况 C 时由用户自己指定;如果其不为 None,那么在模型 forward 的时候,我们就会将数据迁移到 data_device 上; | |||||
model_device 永远都为单独的一个 torch.device; | |||||
情况 A | 情况 B | 情况 C | |||||
________________________________________________________________________________________________________ | |||||
parallel_device | 由用户传入trainer的参数 | | | |||||
| device 决定,必须是一个list, | 为 CUDA_VISIBLE_DEVICES | 为 CUDA_VISIBLE_DEVICES | |||||
| 其中每一个对象都是 int | | | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
model_device | parallel_device[local_rank] | parallel_device | None | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
data_device | model_device | model_device | 由用户传入 trainer 的参数 | |||||
| | | data_device 决定 | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
3. _DDPWrappingModel 的作用; | |||||
因为我们即需要调用模型的 `train_step`、`evaluate_step`、`test_step` 方法,又需要通过 `DataParallel` 的forward 函数来帮助 | |||||
我们同步各个设备上的梯度,因此我们需要先将模型单独包裹一层,然后在 forward 的时候,其先经过 `DataParallel` 的 forward 方法, | |||||
然后再经过 `_DDPWrappingModel` 的 forward 方法,我们会在该 forward 函数中进行判断,确定调用的是模型自己的 forward 函数,还是 | |||||
`train_step`、`evaluate_step`、`test_step` 方法。 | |||||
4. 当某一个进程出现 exception 后,`PaddleFleetDriver` 的处理; | |||||
不管是什么情况,`PaddleFleetDriver` 在 `setup` 函数的最后,都会将所有进程的 pid 主动记录下来,这样当一个进程出现 exception 后, | |||||
driver 的 on_exception 函数就会被 trainer 调用,其会调用 os.kill 指令将其它进程 kill 掉; | |||||
""" | |||||
if USER_CUDA_VISIBLE_DEVICES not in os.environ: | if USER_CUDA_VISIBLE_DEVICES not in os.environ: | ||||
raise RuntimeError("To run paddle distributed training, please set `FASTNLP_BACKEND` to 'paddle' before using FastNLP.") | raise RuntimeError("To run paddle distributed training, please set `FASTNLP_BACKEND` to 'paddle' before using FastNLP.") | ||||
super(PaddleFleetDriver, self).__init__(model, fp16=fp16, **kwargs) | super(PaddleFleetDriver, self).__init__(model, fp16=fp16, **kwargs) | ||||
# 如果不是通过 launch 启动,要求用户必须传入 parallel_device | # 如果不是通过 launch 启动,要求用户必须传入 parallel_device | ||||
if not is_pull_by_paddle_run and parallel_device is None: | |||||
raise ValueError("Parameter `parallel_device` can not be None when using `PaddleFleetDriver`. This error is caused " | |||||
"when your value of parameter `device` is `None` in your `Trainer` instance.") | |||||
if not is_pull_by_paddle_run: | |||||
if parallel_device is None: | |||||
raise ValueError("Parameter `parallel_device` can not be None when using `PaddleFleetDriver`. This error is caused " | |||||
"when your value of parameter `device` is `None` in your `Trainer` instance.") | |||||
if not isinstance(parallel_device, List): | |||||
raise ValueError("Parameter `parallel_device`'s type must be List when using `PaddleFleetDriver`, " | |||||
f"not {type(parallel_device)}.") | |||||
if get_paddle_device_id(parallel_device[0]) != 0: | |||||
raise ValueError("The first device of `parallel_device` must be 'gpu:0' in fastNLP.") | |||||
# 如果用户自己初始化了 paddle 的分布式训练那么一定是通过 launch 拉起的 | # 如果用户自己初始化了 paddle 的分布式训练那么一定是通过 launch 拉起的 | ||||
# 这个参数会在 initialize_paddle_drvier 中设置。 | # 这个参数会在 initialize_paddle_drvier 中设置。 | ||||
@@ -254,10 +224,10 @@ class PaddleFleetDriver(PaddleDriver): | |||||
def setup(self): | def setup(self): | ||||
""" | """ | ||||
根据不同的情况进行不同的设置。 | |||||
1、如果是通过 paddle.distributed.launch 方法启动时,则根据已经设置好的环境获取 | |||||
分布式的属性。 | |||||
2、否则,调用 FleetLauncher 类启动子进程 | |||||
初始化分布式训练的环境。 | |||||
1. 如果是通过 ``paddle.distributed.launch`` 方法启动的,则根据已经设置好的环境获取分布式的属性。 | |||||
2. 否则启动子进程。 | |||||
""" | """ | ||||
if self._has_setup: | if self._has_setup: | ||||
return | return | ||||
@@ -267,7 +237,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
if self.outside_fleet: | if self.outside_fleet: | ||||
# 已经初始化了多机环境 | # 已经初始化了多机环境 | ||||
self.set_from_fleet_environment() | |||||
self._set_from_fleet_environment() | |||||
else: | else: | ||||
# 用户没有初始化多机环境 | # 用户没有初始化多机环境 | ||||
# TODO 绕一下 | # TODO 绕一下 | ||||
@@ -287,7 +257,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
# parallel_device 是 list, | # parallel_device 是 list, | ||||
if not parallel_helper._is_parallel_ctx_initialized(): | if not parallel_helper._is_parallel_ctx_initialized(): | ||||
# 拉起子进程并设置相应的属性 | # 拉起子进程并设置相应的属性 | ||||
self.init_fleet_and_set() | |||||
self._init_fleet_and_set() | |||||
# 用户在这个 trainer 前面又初始化了一个 trainer,并且使用的是 PaddleFleetDriver; | # 用户在这个 trainer 前面又初始化了一个 trainer,并且使用的是 PaddleFleetDriver; | ||||
else: | else: | ||||
# 已经设置过一次,保证参数必须是一样的 | # 已经设置过一次,保证参数必须是一样的 | ||||
@@ -321,7 +291,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
self._pids = self._pids[node_rank*local_world_size: (node_rank+1)*local_world_size] | self._pids = self._pids[node_rank*local_world_size: (node_rank+1)*local_world_size] | ||||
self._pids = self.tensor_to_numeric(self._pids) | self._pids = self.tensor_to_numeric(self._pids) | ||||
def init_fleet_and_set(self): | |||||
def _init_fleet_and_set(self): | |||||
""" | """ | ||||
使用 FleetLauncher 拉起子进程 | 使用 FleetLauncher 拉起子进程 | ||||
""" | """ | ||||
@@ -340,7 +310,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
assert self.world_size is not None | assert self.world_size is not None | ||||
assert self.world_size == len(self.parallel_device) | assert self.world_size == len(self.parallel_device) | ||||
def set_from_fleet_environment(self): | |||||
def _set_from_fleet_environment(self): | |||||
""" | """ | ||||
当用户使用了 `python -m paddle.distributed.launch xxx.py` 启动时,我们需要 | 当用户使用了 `python -m paddle.distributed.launch xxx.py` 启动时,我们需要 | ||||
根据 paddle 设置的环境变量来获得各种属性 | 根据 paddle 设置的环境变量来获得各种属性 | ||||
@@ -349,19 +319,11 @@ class PaddleFleetDriver(PaddleDriver): | |||||
self.global_rank = paddledist.get_rank() | self.global_rank = paddledist.get_rank() | ||||
def barrier(self): | def barrier(self): | ||||
r""" | |||||
用于在多进程工作时同步各进程的工作进度,运行快的进程运行到这里会等待运行慢的进程,只有所有进程都运行到此函数时,所有的进程才会继续运行; | |||||
仅在多分布式训练场景中有使用。 | |||||
注意,该函数的行为会受到 FASTNLP_NO_SYNC 的影响。仅当 FASTNLP_NO_SYNC 在 os.environ 中不存在,或小于 1 时才真的执行 barrier 。 | |||||
""" | |||||
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 | if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 | ||||
paddledist.barrier() | paddledist.barrier() | ||||
def configure_fleet(self): | def configure_fleet(self): | ||||
""" | |||||
将模型用 DataParallel 和自定义的类型包裹起来 | |||||
""" | |||||
# 将模型用 DataParallel 和自定义的类型包裹起来 | |||||
if not self._has_fleetwrapped and not isinstance(self.model, DataParallel): | if not self._has_fleetwrapped and not isinstance(self.model, DataParallel): | ||||
self.model = DataParallel( | self.model = DataParallel( | ||||
_FleetWrappingModel(self.model), | _FleetWrappingModel(self.model), | ||||
@@ -395,10 +357,17 @@ class PaddleFleetDriver(PaddleDriver): | |||||
@property | @property | ||||
def model_device(self): | def model_device(self): | ||||
""" | |||||
:return: 模型所在的设备; | |||||
""" | |||||
return self._model_device | return self._model_device | ||||
@property | @property | ||||
def data_device(self): | def data_device(self): | ||||
""" | |||||
:return: 数据所在的设备;由于 **PaddlePaddle** 可以通过环境变量获取当前进程的设备,因此该属性 | |||||
和 ``model_device`` 表现相同; | |||||
""" | |||||
return self.model_device | return self.model_device | ||||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | ||||
@@ -440,7 +409,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
# 暂时不支持iterableDataset | # 暂时不支持iterableDataset | ||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | assert dataloader.dataset_kind != _DatasetKind.ITER, \ | ||||
"FastNLP does not support `IteratorDataset` now." | "FastNLP does not support `IteratorDataset` now." | ||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用; | |||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load_checkpoint 函数调用; | |||||
if isinstance(dist, ReproducibleBatchSampler): | if isinstance(dist, ReproducibleBatchSampler): | ||||
dist.set_distributed( | dist.set_distributed( | ||||
num_replicas=self.world_size, | num_replicas=self.world_size, | ||||
@@ -522,23 +491,29 @@ class PaddleFleetDriver(PaddleDriver): | |||||
else: | else: | ||||
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | ||||
def is_global_zero(self): | |||||
def is_global_zero(self) -> bool: | |||||
return self.global_rank == 0 | return self.global_rank == 0 | ||||
def get_model_no_sync_context(self): | def get_model_no_sync_context(self): | ||||
return self.model.no_sync | return self.model.no_sync | ||||
def unwrap_model(self): | |||||
def unwrap_model(self) -> "paddle.nn.Layer": | |||||
""" | |||||
获得 driver 最原始的模型。该函数可以取出被 :class:`paddle.DataParallel` 包裹起来的模型。 | |||||
""" | |||||
_layers = self.model._layers | _layers = self.model._layers | ||||
if isinstance(_layers, _FleetWrappingModel): | if isinstance(_layers, _FleetWrappingModel): | ||||
return _layers.model | return _layers.model | ||||
else: | else: | ||||
return _layers | return _layers | ||||
def get_local_rank(self) ->int: | |||||
def get_local_rank(self) -> int: | |||||
return self.local_rank | return self.local_rank | ||||
def is_distributed(self): | |||||
def is_distributed(self) -> bool: | |||||
""" | |||||
判断是否为分布式的 **Driver** ,在 ``PaddleFleetDriver`` 中,返回 ``True``。 | |||||
""" | |||||
return True | return True | ||||
@staticmethod | @staticmethod | ||||
@@ -40,8 +40,8 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
if user_visible_devices is None: | if user_visible_devices is None: | ||||
raise RuntimeError("To run paddle distributed training, please set `FASTNLP_BACKEND` to 'paddle' before using FastNLP.") | raise RuntimeError("To run paddle distributed training, please set `FASTNLP_BACKEND` to 'paddle' before using FastNLP.") | ||||
if device is not None: | if device is not None: | ||||
logger.warning_once("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull " | |||||
"up your script. And we will directly get the local device via environment variables.") | |||||
logger.rank_zero_warning("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull " | |||||
"up your script. And we will directly get the local device via environment variables.", once=True) | |||||
_visible_list = user_visible_devices.split(",") | _visible_list = user_visible_devices.split(",") | ||||
device = [ f"gpu:{_visible_list.index(g) }" for g in os.environ["CUDA_VISIBLE_DEVICES"].split(",")] | device = [ f"gpu:{_visible_list.index(g) }" for g in os.environ["CUDA_VISIBLE_DEVICES"].split(",")] | ||||
# TODO 目前一个进程仅对应一个卡,所以暂时传入单个 | # TODO 目前一个进程仅对应一个卡,所以暂时传入单个 | ||||
@@ -1,14 +1,12 @@ | |||||
import os | import os | ||||
import random | import random | ||||
from typing import Union, Optional, Dict | |||||
from typing import Union, Optional, Dict, Any | |||||
from pathlib import Path | from pathlib import Path | ||||
from functools import partial | from functools import partial | ||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
import numpy as np | import numpy as np | ||||
from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES | |||||
from .utils import _build_fp16_env, optimizer_state_to_device, DummyGradScaler | from .utils import _build_fp16_env, optimizer_state_to_device, DummyGradScaler | ||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
from fastNLP.core.drivers.driver import Driver | from fastNLP.core.drivers.driver import Driver | ||||
@@ -50,9 +48,26 @@ if _NEED_IMPORT_PADDLE: | |||||
class PaddleDriver(Driver): | class PaddleDriver(Driver): | ||||
r""" | r""" | ||||
Paddle框架的Driver,包括实现单卡训练的 `PaddleSingleDriver` 和分布式训练的 `PaddleFleetDriver`。 | |||||
实现了 **PaddlePaddle** 框架训练功能的基本 Driver,实现了单卡和多卡情景下均需要实现的功能,以和 **fastNLP** 的 | |||||
:class:`~fastNLP.core.Trainer` 兼容;通过这个 Driver,可以在 **fastNLP** 中实现从 **Pytorch** 框架到 | |||||
**PaddlePaddle** 深度学习框架的切换。 | |||||
这个类被以下子类继承: | |||||
1. :class:`~fastNLP.core.drivers.PaddleSingleDriver`:实现了使用单卡和 ``cpu`` 训练的具体功能; | |||||
2. :class:`~fastNLP.core.drivers.PaddleFleetDriver`:实现了使用 ``fleet`` 分布式训练 API 进行集群式分布式训练的具体功能; | |||||
:param model: 训练时使用的 **PaddlePaddle** 模型; | |||||
:param fp16: 是否开启混合精度训练; | |||||
:kwargs: | |||||
* wo_auto_param_call (``bool``) -- 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为; | |||||
.. note:: | |||||
关于该参数的详细说明,请参见 :class:`~fastNLP.core.controllers.Trainer` 中的描述;函数 ``auto_param_call`` 详见 :func:`fastNLP.core.utils.auto_param_call`。 | |||||
""" | """ | ||||
def __init__(self, model, fp16: Optional[bool] = False, **kwargs): | |||||
def __init__(self, model: "paddle.nn.Layer", fp16: Optional[bool] = False, **kwargs): | |||||
if not isinstance(model, paddle.nn.Layer): | if not isinstance(model, paddle.nn.Layer): | ||||
raise ValueError(f"Parameter `model` can not be `{type(model)}` in `PaddleDriver`, it should be exactly " | raise ValueError(f"Parameter `model` can not be `{type(model)}` in `PaddleDriver`, it should be exactly " | ||||
f"`paddle.nn.Layer` type.") | f"`paddle.nn.Layer` type.") | ||||
@@ -69,10 +84,10 @@ class PaddleDriver(Driver): | |||||
def zero_grad(self, set_to_none: bool = False): | def zero_grad(self, set_to_none: bool = False): | ||||
r""" | r""" | ||||
实现深度学习中的梯度的置零操作,应当直接通过优化器 optimizers 来将梯度置零; | |||||
注意梯度累积不需要在这里实现,trainer 已经在内部实现了梯度累积; | |||||
实现深度学习中的梯度的置零操作,应当直接通过优化器 ``optimizers`` 来将梯度置零; | |||||
注意梯度累积不需要在这里实现,:class:`~fastNLP.core.Trainer` 已经在内部实现了梯度累积; | |||||
:param set_to_none: 用来判断是否需要将梯度直接置为 None;Paddle中这个参数无效。 | |||||
:param set_to_none: 用来判断是否需要将梯度直接置为 ``None``;在 **PaddlePaddle** 中这个参数无效。 | |||||
""" | """ | ||||
for optimizer in self.optimizers: | for optimizer in self.optimizers: | ||||
optimizer.clear_grad() | optimizer.clear_grad() | ||||
@@ -87,14 +102,6 @@ class PaddleDriver(Driver): | |||||
@staticmethod | @staticmethod | ||||
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | ||||
r""" | |||||
该函数会在 trainer 或者 evaluator 设置 dataloader 后检测 dataloader 的合法性。 | |||||
要求传入的 dataloader 必须为 `paddle.io.DataLoader` 或包含该类型的字典。 | |||||
:param dataloader: 需要检测的输入的 `dataloader`; | |||||
:param dataloader_name: | |||||
:param is_train: | |||||
""" | |||||
if is_train: | if is_train: | ||||
if not isinstance(dataloader, DataLoader): | if not isinstance(dataloader, DataLoader): | ||||
raise ValueError(f"Parameter `{dataloader_name}` should be 'paddle.io.DataLoader' type, not {type(dataloader)}.") | raise ValueError(f"Parameter `{dataloader_name}` should be 'paddle.io.DataLoader' type, not {type(dataloader)}.") | ||||
@@ -164,16 +171,15 @@ class PaddleDriver(Driver): | |||||
@rank_zero_call | @rank_zero_call | ||||
def save_model(self, filepath: str, only_state_dict: bool = True, **kwargs): | def save_model(self, filepath: str, only_state_dict: bool = True, **kwargs): | ||||
r""" | r""" | ||||
保存模型的函数;注意函数 `save` 是用来进行断点重训的函数; | |||||
将模型保存到 ``filepath`` 中。 | |||||
:param filepath: 保存文件的文件位置(需要包括文件名); | :param filepath: 保存文件的文件位置(需要包括文件名); | ||||
:param only_state_dict: 是否只保存模型的 `state_dict`;如果为 False,则会调用 `paddle.jit.save` 函数 | |||||
保存整个模型的参数,此时需要传入 `input_spec` 参数,否则在 load 时会报错。 | |||||
:param kwargs: | |||||
input_spec: 描述存储模型 forward 方法的输入,当 `only_state_dict` 为 False时必须传入,否则加载时会报错。 | |||||
可以通过 InputSpec 或者示例 Tensor 进行描述。详细的可以参考 paddle 关于`paddle.jit.save` | |||||
的文档: | |||||
https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/jit/save_cn.html#save | |||||
:param only_state_dict: 是否只保存模型的 ``state_dict``;如果为 ``False``,则会调用 ``paddle.jit.save`` | |||||
函数保存整个模型的参数,此时需要传入 ``input_spec`` 参数; | |||||
:kwargs: | |||||
* input_spec -- 描述存储模型 ``forward`` 方法的输入; | |||||
当 ``only_state_dict`` 为 ``False`` 时必须传入,否则加载时会报错。您可以通过 ``InputSpec`` 或者示例 ``Tensor`` | |||||
进行描述。详细的使用方法可以参考 **PaddlePaddle** `关于 paddle.jit.save 函数的文档 <https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/jit/save_cn.html#save>`_; | |||||
""" | """ | ||||
model = self.unwrap_model() | model = self.unwrap_model() | ||||
if isinstance(filepath, Path): | if isinstance(filepath, Path): | ||||
@@ -189,14 +195,6 @@ class PaddleDriver(Driver): | |||||
paddle.jit.save(model, filepath, input_spec) | paddle.jit.save(model, filepath, input_spec) | ||||
def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs): | def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs): | ||||
r""" | |||||
加载模型的函数;将 filepath 中的模型加载并赋值给当前 model 。 | |||||
:param filepath: 需要被加载的对象的文件位置(需要包括文件名); | |||||
:param load_state_dict: 保存的文件是否只是模型的权重,还是完整的模型。即便是保存的完整的模型,此处也只能使用尝试加载filepath | |||||
模型中的权重到自身模型,而不会直接替代当前 Driver 中的模型。 | |||||
:return: 返回加载指定文件后的结果; | |||||
""" | |||||
model = self.unwrap_model() | model = self.unwrap_model() | ||||
if isinstance(filepath, Path): | if isinstance(filepath, Path): | ||||
filepath = str(filepath) | filepath = str(filepath) | ||||
@@ -209,7 +207,28 @@ class PaddleDriver(Driver): | |||||
model.load_dict(paddle.load(filepath)) | model.load_dict(paddle.load(filepath)) | ||||
@rank_zero_call | @rank_zero_call | ||||
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||||
def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||||
r""" | |||||
断点重训的保存函数,该函数会负责保存模型和 optimizers, fp16 的 state_dict;以及模型的保存(若 should_save_model 为 True) | |||||
:param folder: 保存断点重训的状态的文件夹;save 函数应该在下面新增两(一)个文件 的 FASTNLP_CHECKPOINT_FILENAME 文件与 | |||||
FASTNLP_MODEL_FILENAME (如果 should_save_model 为 True )。把 model 相关的内容放入到 FASTNLP_MODEL_FILENAME 文件中, | |||||
将传入的 states 以及自身产生其它状态一并保存在 FASTNLP_CHECKPOINT_FILENAME 里面。 | |||||
:param states: 由 trainer 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态,Driver 应该只需要保存该对象即可, | |||||
Driver 应该不需要理解该对象,同时在 driver.load_checkpoint() 的时候,需要将 states 返回回去,load() 返回的值与这里的传入的值保持一致。 | |||||
:param dataloader: 正在使用的 dataloader,需要保存里面的状态使得之后可以从当前迭代的位置恢复。 | |||||
:param only_state_dict: 是否只保存模型的参数,当 should_save_model 为 False ,该参数无效。 | |||||
:param should_save_model: 是否应该保存模型,如果为False,Driver 将不负责 model 的保存。 | |||||
:kwargs: | |||||
* input_spec -- 描述存储模型 ``forward`` 方法的输入; | |||||
当 ``only_state_dict`` 为 ``False`` 时必须传入,否则加载时会报错。您可以通过 ``InputSpec`` 或者示例 ``Tensor`` | |||||
进行描述。详细的使用方法可以参考 **PaddlePaddle** `关于 paddle.jit.save 函数的文档 <https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/jit/save_cn.html#save>`_; | |||||
.. todo: | |||||
等 Driver 的文档写完 | |||||
""" | |||||
# 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变 | # 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变 | ||||
# trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; | # trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; | ||||
@@ -278,7 +297,7 @@ class PaddleDriver(Driver): | |||||
paddle.save(states, str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) | paddle.save(states, str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) | ||||
def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||||
def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||||
states = paddle.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) | states = paddle.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) | ||||
@@ -352,37 +371,41 @@ class PaddleDriver(Driver): | |||||
r""" | r""" | ||||
返回一个不计算梯度的环境用来对模型进行评测; | 返回一个不计算梯度的环境用来对模型进行评测; | ||||
:return: context 上下文对象 `paddle.no_grad`; | |||||
:return: 上下文对象 ``paddle.no_grad``; | |||||
""" | """ | ||||
return paddle.no_grad | return paddle.no_grad | ||||
@staticmethod | @staticmethod | ||||
def move_model_to_device(model: "paddle.nn.Layer", device: Union[str, int, "paddle.CUDAPlace", "paddle.CPUPlace"]): | def move_model_to_device(model: "paddle.nn.Layer", device: Union[str, int, "paddle.CUDAPlace", "paddle.CPUPlace"]): | ||||
r""" | r""" | ||||
用来将模型转移到指定的 device 上; | |||||
在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 | |||||
用来将模型 ``model`` 转移到指定的设备上; | |||||
.. note:: | |||||
在 **Paddle** 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 | |||||
:param model: 需要进行转移的模型; | |||||
:param device: 目标设备; | |||||
""" | """ | ||||
if device is not None: | if device is not None: | ||||
model.to(device) | model.to(device) | ||||
def move_data_to_device(self, batch: "paddle.Tensor"): | |||||
def move_data_to_device(self, batch: Any) -> Any: | |||||
r""" | r""" | ||||
将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。 | |||||
在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 | |||||
将数据集合 ``batch`` 迁移到指定的机器上。 | |||||
.. note:: | |||||
在 **Paddle** 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 | |||||
:return: 将移动到指定机器上的 batch 对象返回; | |||||
:param batch: 包含 :class:`paddle.Tensor` 的数据集合,可以是 **List**、**Dict** 等嵌套类型; | |||||
:return: 移动到指定机器后的 `batch``; | |||||
""" | """ | ||||
device = _convert_data_device(self.data_device) | device = _convert_data_device(self.data_device) | ||||
return paddle_move_data_to_device(batch, device) | return paddle_move_data_to_device(batch, device) | ||||
@staticmethod | @staticmethod | ||||
def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover | def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover | ||||
"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed | |||||
with ``seed_everything(seed, workers=True)``. | |||||
See also the PyTorch documentation on | |||||
`randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_. | |||||
""" | |||||
# implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 | # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 | ||||
global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) | global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) | ||||
# TODO gpu | # TODO gpu | ||||
@@ -409,9 +432,6 @@ class PaddleDriver(Driver): | |||||
@staticmethod | @staticmethod | ||||
def get_dataloader_args(dataloader: "DataLoader"): | def get_dataloader_args(dataloader: "DataLoader"): | ||||
""" | |||||
获取 dataloader 的 shuffle 和 drop_last 属性; | |||||
""" | |||||
@dataclass | @dataclass | ||||
class Res: | class Res: | ||||
@@ -33,9 +33,20 @@ __all__ = [ | |||||
class PaddleSingleDriver(PaddleDriver): | class PaddleSingleDriver(PaddleDriver): | ||||
""" | """ | ||||
支持 paddle cpu 或单卡 gpu 训练的 driver | |||||
实现了 **PaddlePaddle** 框架下在单卡或 ``cpu`` 环境下训练功能的 **Driver**。 | |||||
:param model: 训练时使用的 **PaddlePaddle** 模型; | |||||
:param device: 训练使用的设备; | |||||
:param fp16: 是否开启混合精度训练; | |||||
:kwargs: | |||||
* wo_auto_param_call (``bool``) -- 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为; | |||||
.. note:: | |||||
关于该参数的详细说明,请参见 :class:`~fastNLP.core.controllers.Trainer` 中的描述;函数 ``auto_param_call`` 详见 :func:`fastNLP.core.utils.auto_param_call`。 | |||||
""" | """ | ||||
def __init__(self, model, device: Union[str, int], fp16: Optional[bool] = False, **kwargs): | |||||
def __init__(self, model: "paddle.nn.Layer", device: Union[str, int], fp16: Optional[bool] = False, **kwargs): | |||||
if isinstance(model, DataParallel): | if isinstance(model, DataParallel): | ||||
raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`") | raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`") | ||||
@@ -62,7 +73,7 @@ class PaddleSingleDriver(PaddleDriver): | |||||
def setup(self): | def setup(self): | ||||
r""" | r""" | ||||
该函数用来初始化训练环境,用于设置当前训练的设备,并将模型迁移到对应设备上。 | |||||
初始化训练环境;设置当前训练的设备,并将模型迁移到对应设备上。 | |||||
""" | """ | ||||
device = _convert_data_device(self.data_device) | device = _convert_data_device(self.data_device) | ||||
@@ -95,7 +106,7 @@ class PaddleSingleDriver(PaddleDriver): | |||||
# 暂时不支持iterableDataset | # 暂时不支持iterableDataset | ||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | assert dataloader.dataset_kind != _DatasetKind.ITER, \ | ||||
"FastNLP does not support `IteratorDataset` now." | "FastNLP does not support `IteratorDataset` now." | ||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load_checkpoint 函数调用; | |||||
if isinstance(dist, ReproducibleBatchSampler): | if isinstance(dist, ReproducibleBatchSampler): | ||||
return replace_batch_sampler(dataloader, dist) | return replace_batch_sampler(dataloader, dist) | ||||
elif isinstance(dist, ReproducibleSampler): | elif isinstance(dist, ReproducibleSampler): | ||||
@@ -127,17 +138,20 @@ class PaddleSingleDriver(PaddleDriver): | |||||
return dataloader | return dataloader | ||||
def unwrap_model(self): | def unwrap_model(self): | ||||
if isinstance(self.model, paddle.DataParallel): | |||||
return self.model._layers | |||||
else: | |||||
return self.model | |||||
""" | |||||
返回训练使用的模型。 | |||||
""" | |||||
return self.model | |||||
@property | @property | ||||
def data_device(self): | |||||
def data_device(self) -> str: | |||||
""" | """ | ||||
返回数据所在的设备。由于单卡模式不支持 data_device,因此返回的是 model_device | |||||
:return: 数据和模型所在的设备; | |||||
""" | """ | ||||
return self.model_device | return self.model_device | ||||
def is_distributed(self): | |||||
def is_distributed(self) -> bool: | |||||
""" | |||||
判断是否为分布式的 **Driver** ,在 ``PaddleSingleDriver`` 中,返回 ``False``。 | |||||
""" | |||||
return False | return False |
@@ -31,7 +31,15 @@ __all__ = [ | |||||
def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int: | def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int: | ||||
return random.randint(min_seed_value, max_seed_value) | return random.randint(min_seed_value, max_seed_value) | ||||
def paddle_seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: | |||||
def paddle_seed_everything(seed: Optional[int], workers: bool = False) -> int: | |||||
r""" | |||||
为 **paddle**、**numpy**、**python.random** 伪随机数生成器设置种子。 | |||||
:param seed: 全局随机状态的整数值种子。如果为 ``None``,将从环境变量 ``FASTNLP_GLOBAL_SEED`` 中读取种子或随机选择; | |||||
:param workers: 如果为 ``True`` ,则会设置环境变量 ``FASTNLP_SEED_WORKERS`` 。该环境变量会在 :class:`~fastNLP.core.Trainer` | |||||
中配置 ``dataloader`` 时用于设置 ``worker_init_fn`` 。如果用户已经为 ``dataloader`` 提供了 ``worker_init_fn`` ,则设置 | |||||
此参数将没有影响; | |||||
""" | |||||
max_seed_value = np.iinfo(np.uint32).max | max_seed_value = np.iinfo(np.uint32).max | ||||
min_seed_value = np.iinfo(np.uint32).min | min_seed_value = np.iinfo(np.uint32).min | ||||
@@ -70,7 +78,7 @@ def paddle_seed_everything(seed: Optional[int] = None, workers: bool = False) -> | |||||
def reset_seed() -> None: | def reset_seed() -> None: | ||||
""" | """ | ||||
fleet 会开启多个进程,因此当用户在脚本中指定 seed_everything 时,在开启多个脚本后,会在每个脚本内重新 | |||||
``fleet`` 会开启多个进程,因此当用户在脚本中指定 ``seed_everything`` 时,在开启多个脚本后,会在每个脚本内重新 | |||||
进行随机数的设置; | 进行随机数的设置; | ||||
""" | """ | ||||
seed = os.environ.get(FASTNLP_GLOBAL_SEED, None) | seed = os.environ.get(FASTNLP_GLOBAL_SEED, None) | ||||
@@ -80,8 +88,8 @@ def reset_seed() -> None: | |||||
class _FleetWrappingModel(Layer): | class _FleetWrappingModel(Layer): | ||||
""" | """ | ||||
参考 _DDPWrappingModel , paddle 的分布式训练也需要用 paddle.nn.DataParallel 进行包装,采用和 | |||||
pytorch 相似的处理方式 | |||||
参考 :class:`fastNLP.core.drivers.torch_driver.utils._DDPWrappingModel` , **PaddlePaddle** 的分布式训练也需要用 :class:`paddle.nn.DataParallel` 进行包装,采用和 | |||||
**pytorch** 相似的处理方式 | |||||
""" | """ | ||||
def __init__(self, model: 'nn.Layer'): | def __init__(self, model: 'nn.Layer'): | ||||
super(_FleetWrappingModel, self).__init__() | super(_FleetWrappingModel, self).__init__() | ||||
@@ -100,7 +108,7 @@ class _FleetWrappingModel(Layer): | |||||
class DummyGradScaler: | class DummyGradScaler: | ||||
""" | """ | ||||
用于仿造的GradScaler对象,防止重复写大量的if判断 | |||||
用于仿造的 **GradScaler** 对象,防止重复写大量的if判断 | |||||
""" | """ | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
pass | pass | ||||
@@ -144,7 +152,7 @@ def _build_fp16_env(dummy=False): | |||||
def find_free_ports(num): | def find_free_ports(num): | ||||
""" | """ | ||||
在空闲的端口中找到 num 个端口 | |||||
在空闲的端口中找到 ``num`` 个端口 | |||||
""" | """ | ||||
def __free_port(): | def __free_port(): | ||||
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: | with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: | ||||
@@ -174,8 +182,8 @@ def find_free_ports(num): | |||||
def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler"): | def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler"): | ||||
""" | """ | ||||
利用 `batch_sampler` 重新构建一个 DataLoader,起到替换 `batch_sampler` 又不影响原 `dataloader` 的作用。 | |||||
考虑了用户自己定制了 DataLoader 的情形。 | |||||
利用 ``batch_sampler`` 重新构建一个 ``DataLoader``,起到替换 ``batch_sampler`` 又不影响原 ``dataloader`` 的作用。 | |||||
考虑了用户自己定制了 ``DataLoader`` 的情形。 | |||||
""" | """ | ||||
# 拿到非下划线开头的实例属性; | # 拿到非下划线开头的实例属性; | ||||
instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')} | instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')} | ||||
@@ -246,7 +254,7 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler | |||||
def replace_sampler(dataloader, new_sampler): | def replace_sampler(dataloader, new_sampler): | ||||
""" | """ | ||||
使用 `new_sampler` 重新构建一个 BatchSampler,并替换到 `dataloader` 中 | |||||
使用 ``new_sampler`` 重新构建一个 ``BatchSampler``,并替换到 ``dataloader`` 中 | |||||
""" | """ | ||||
new_batch_sampler = deepcopy(dataloader.batch_sampler) | new_batch_sampler = deepcopy(dataloader.batch_sampler) | ||||
new_batch_sampler.sampler = new_sampler | new_batch_sampler.sampler = new_sampler | ||||
@@ -1,3 +1,130 @@ | |||||
r""" | |||||
""" | |||||
r""" | |||||
`TorchDDPDriver` 目前支持的三种启动方式: | |||||
1. 用户自己不进行 ddp 的任何操作,直接使用我们的 Trainer,这时是由我们自己使用 `open_subprocesses` 拉起多个进程, | |||||
然后 `TorchDDPDriver` 自己通过调用 `dist.init_process_group` 来初始化 ddp 的通信组;(情况 A) | |||||
2. 用户同样不在 Trainer 之外初始化 ddp,但是用户自己使用 python -m torch.distributed.launch 拉起来创建多个进程,这时我们仍旧 | |||||
会通过调用 `dist.init_process_group` 来初始化 ddp 的通信组;(情况 B) | |||||
3. 用户自己在外面初始化 DDP,并且通过 python -m torch.distributed.launch 拉起,这时无论是多个进程的拉起和 ddp 的通信组的建立 | |||||
都由用户自己操作,我们只会在 driver.setup 的时候对 `TorchDDPDriver` 设置一些必要的属性值;(情况 C) | |||||
注意多机的启动强制要求用户在每一台机器上使用 python -m torch.distributed.launch 启动;因此我们不会在 `TorchDDPDriver` 中保存 | |||||
任何当前有多少台机器的信息(num_nodes,不是 gpu 的数量); | |||||
Part 1:三种启动方式的具体分析: | |||||
(1)对于用户运行的脚本中,如果 `driver.setup` 只会被调用一次(意味着用户的启动脚本中只初始化了一个 trainer/evaluator)时, | |||||
`TorchDDPDriver` 在初始化以及 `setup` 函数中会做的事情分别如下所示: | |||||
-> 情况 A:这种情况下用户传入的 model 在一定是普通的 model(没有经 `DistributedDataParallel` 包裹的model), | |||||
因为 `DistributedDataParallel` 的使用一定要求 init_process_group 已经被调用用来建立当前的 ddp 通信组;但是这意味着如果 | |||||
用户需要使用 2 张以上的显卡,那么其必然需要使用 torch.distributed.launch 来启动,意味着就不是情况 A 了; | |||||
这时我们首先会调用 `TorchDDPDriver.open_subprocess` 函数来拉起多个进程,其中进程的数量等于用户传入给 trainer 的使用的 gpu | |||||
的数量(例如 `Trainer` 中的参数是 device=[0, 1, 6, 7],那么我们就会使用第 0、1、6、7 张 gpu 来拉起 4 个进程); | |||||
接着我们会调用 `dist.init_process_group` 来初始化各个进程之间的通信组; | |||||
这里需要注意拉起的新的进程会从前到后完整地运行一遍用户的启动脚本(例如 main.py),因此也都会运行这两个函数,但是需要注意只有进程 0 | |||||
才会去真正地运行 `TorchDDPDriver.open_subprocess`;进程 0 运行到 `dist.init_process_group`,pytorch 会阻塞进程 0 继续 | |||||
向前运行,直到其它进程也运行到这里; | |||||
最后我们会设置这个进程对应的 device,然后将模型迁移到对应的机器上,再使用 `DistributedDataParallel` 将模型包裹; | |||||
至此,ddp 的环境配置过程全部完成; | |||||
-> 情况 B:注意这种情况我们直接限定了用户是通过 torch.distributed.launch 拉起,并且没有自己建立 ddp 的通信组。这时在 | |||||
`TorchDDPDriver` 的初始化和 setup 函数的调用过程中,与情况 A 首要的不同就在于用户在 trainer 中输入的参数 device 不再有效, | |||||
这时每个进程所使用的 gpu 是我们直接通过 `torch.device("cuda:{local_rank}")` 来配置的;因此,如果用户想要实现使用特定 gpu | |||||
设备的目的,可以通过自己设置环境变量实现(例如 os.environ["CUDA_VISIBLE_DEVICE"] 来实现);剩下的操作和情况 A 类似; | |||||
-> 情况 C:注意这种情况我们限定了用户是通过 torch.distributed.launch 拉起,并且 ddp 的通信组也是由自己建立。这时基本上所有的 | |||||
与操作相关的操作都应当由用户自己完成,包括迁移模型到对应 gpu 上以及将模型用 `DistributedDataParallel` 包裹等。 | |||||
(2)如果 `driver.setup` 函数在脚本中会被调用两次及以上(意味着用户的启动脚本初始化了两个及以上的 trainer/evaluator)时: | |||||
注意这种情况下我们是会保证前后两个 trainer/evaluator 使用的 `TorchDDPDriver` 以及其初始化方式的一致性,换句话说,如果 trainer1 | |||||
检测到的启动方式是 '情况 A',那么我们会保证 trainer2 检测到的启动方式同样是 '情况A'(即使这需要一些额外的处理);因此这里我们主要讨论 | |||||
我们是通过怎样的操作来保证 trainer2/3/... 检测到的启动方式是和 trainer1 一致的;简单来说,我们是通过使用环境变量来标记每一种不同的 | |||||
启动方式来实现这一点的: | |||||
我们会使用 `FASTNLP_DISTRIBUTED_CHECK` 来标记 '情况 A',使用 `fastnlp_torch_launch_not_ddp` 来标记 '情况 B',意味着我们在 | |||||
使用 '情况 A' 来启动 `TorchDDPDriver` 时,我们会将 `FASTNLP_DISTRIBUTED_CHECK` 这一字符串注入到环境变量中,而 '情况 B' 时则 | |||||
会将 `fastnlp_torch_launch_not_ddp` 这一字符串注入到环境变量中。因此在 trainer2 的 `TorchDDPDriver` 的初始化和 setup 过程中, | |||||
如果检测到这些特殊的环境变量,我们就会将启动方式变更为其对应的启动方式,即使其它的参数特征属于另外的启动方式。 | |||||
Part 2:对应的代码细节: | |||||
1. 如何判断当前的各进程之间的通信组已经被建立(ddp 已经被初始化); | |||||
dist.is_initialized(); | |||||
2. 如何判断不同的进程是否是由 `python -m torch.distributed.launch` 拉起还是由我们的 `TorchDDPDriver.open_subprocess` | |||||
函数拉起; | |||||
我们会在用户脚本 `import fastNLP` 的时候检测当前的环境变量中是否有 'LOCAL_RANK'、'WORLD_SIZE' 以及没有 `FASTNLP_DISTRIBUTED_CHECK`, | |||||
如果满足条件,则我们会向环境变量中注入特殊的值 'FASTNLP_BACKEND_LAUNCH' 来标记用户是否使用了 `python -m torch.distributed.launch` | |||||
来拉起多个进程; | |||||
3. 整体的处理判断流程: | |||||
___________________________________ | |||||
|进入 TorchDDPDriver 的 __init__ 函数| | |||||
——————————————————————————————————— | |||||
↓ | |||||
___________________________________________________ | |||||
| 判断不同的进程是否是由 torch.distributed.launch 拉起 | | |||||
|(或者我们自己的 open_subprocess 函数拉起) | --------------> | |||||
——————————————————————————————————————————————————— | | |||||
↓ 是由 torch.distributed.launch 拉起 | 我们自己的 open_subprocess 函数拉起多个进程 | |||||
___________________________ | | |||||
←←←←← | 检测用户是否自己初始化了 ddp | | | |||||
↓ ——————————————————————————— ↓ | |||||
↓ ↓ 是 ________ | |||||
↓ ______ | 情况 A | | |||||
↓ 否 |情况 C| ————————— | |||||
↓ ——————— | |||||
↓ | |||||
↓ ______ | |||||
↓ -----------> |情况 B| | |||||
——————— | |||||
4. 为了完成全部的建立 ddp 所需要的操作,三种情况都需要做的事情,以及每件事情的职责归属: | |||||
情况 A | 情况 B | 情况 C | |||||
________________________________________________________________________________________________________ | |||||
配置 ddp 所 | TorchDDPDriver.open_subprocess | torch.distributed.launch| torch.distributed.launch | |||||
需要的环境变量 | | | | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
开启多个进程 | TorchDDPDriver.open_subprocess | torch.distributed.launch| torch.distributed.launch | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
调用 dist. | | | | |||||
init_process\ | TorchDDPDriver.setup | TorchDDPDriver.setup | 用户自己调用 | |||||
_group 函数 | | | | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
设置 TorchDDPDriver | | | | |||||
的 world_size 和 | TorchDDPDriver.setup | TorchDDPDriver.setup | TorchDDPDriver.setup | |||||
global_rank 属性 | | | | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
Part 3:其它的处理细节: | |||||
1. 环境变量; | |||||
fastNLP 的 `TorchDDPDriver` 运行时所需要的环境变量分为两种,一种是 torch 的 ddp 运行所需要的环境变量;另一种是 fastNLP 自己 | |||||
的环境变量。前者的配置情况如上表所示;而后者中的大多数环境变量则是在用户 import fastNLP 时就设置好了; | |||||
2. parallel_device, model_device 和 data_device 的关系; | |||||
parallel_device 为 `TorchDDPDriver` 的参数,model_device 和 data_device 都为 driver 的属性; | |||||
其中 data_device 仅当情况 C 时由用户自己指定;如果其不为 None,那么在模型 forward 的时候,我们就会将数据迁移到 data_device 上; | |||||
model_device 永远都为单独的一个 torch.device; | |||||
情况 A | 情况 B | 情况 C | |||||
________________________________________________________________________________________________________ | |||||
parallel_device | 由用户传入trainer的参数 | 为 torch.device( | 为 torch.device( | |||||
| device 决定,必须是一个list, | "cuda:{local_rank}") | "cuda:{local_rank}") | |||||
| 其中每一个对象都是 torch.device | | | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
model_device | parallel_device[local_rank] | parallel_device | None | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
data_device | model_device | model_device | 由用户传入 trainer 的参数 | |||||
| | | data_device 决定 | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
3. _DDPWrappingModel 的作用; | |||||
因为我们即需要调用模型的 `train_step`、`evaluate_step`、`test_step` 方法,又需要通过 `DistributedDataParallel` 的 | |||||
forward 函数来帮助我们同步各个设备上的梯度,因此我们需要先将模型单独包裹一层,然后在 forward 的时候,其先经过 `DistributedDataParallel` | |||||
的 forward 方法,然后再经过 `_DDPWrappingModel` 的 forward 方法,我们会在该 forward 函数中进行判断,确定调用的是模型自己的 | |||||
forward 函数,还是 `train_step`、`evaluate_step`、`test_step` 方法。 | |||||
4. 当某一个进程出现 exception 后,`TorchDDPDriver` 的处理; | |||||
不管是什么情况,`TorchDDPDriver` 在 `setup` 函数的最后,都会将所有进程的 pid 主动记录下来,这样当一个进程出现 exception 后, | |||||
driver 的 on_exception 函数就会被 trainer 调用,其会调用 os.kill 指令将其它进程 kill 掉; | |||||
""" | |||||
import os | import os | ||||
import sys | import sys | ||||
import __main__ | import __main__ | ||||
@@ -7,6 +134,7 @@ from time import sleep | |||||
from typing import List, Optional, Union, Dict, Tuple, Callable | from typing import List, Optional, Union, Dict, Tuple, Callable | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
import torch.distributed as dist | import torch.distributed as dist | ||||
@@ -26,7 +154,8 @@ from fastNLP.core.drivers.torch_driver.utils import ( | |||||
) | ) | ||||
from fastNLP.core.drivers.utils import distributed_open_proc | from fastNLP.core.drivers.utils import distributed_open_proc | ||||
from fastNLP.core.utils import auto_param_call, check_user_specific_params | from fastNLP.core.utils import auto_param_call, check_user_specific_params | ||||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, ReproducibleBatchSampler, \ | |||||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, \ | |||||
ReproducibleBatchSampler, \ | |||||
re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler | re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler | ||||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC | from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -34,6 +163,81 @@ from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gathe | |||||
class TorchDDPDriver(TorchDriver): | class TorchDDPDriver(TorchDriver): | ||||
r""" | |||||
``TorchDDPDriver`` 通过开启多个进程,让每个进程单独使用一个 gpu 设备来实现分布式训练; | |||||
.. note:: | |||||
您在绝大多数情况下不需要自己使用到该类,通过向 ``Trainer`` 传入正确的参数,您可以方便快速地部署您的分布式训练; | |||||
``TorchDDPDriver`` 目前支持的三种启动方式: | |||||
1. 用户自己不进行 ``ddp`` 的任何操作,直接使用我们的 ``Trainer``,这时是由我们自己使用 ``open_subprocesses`` 拉起多个进程, | |||||
然后 ``TorchDDPDriver`` 自己通过调用 ``dist.init_process_group`` 来初始化 ddp 的通信组;(情况 A) | |||||
.. code-block:: | |||||
trainer = Trainer( | |||||
... | |||||
driver='torch', | |||||
device=[0, 1] | |||||
) | |||||
trainer.run() | |||||
通过运行 ``python train.py`` 启动; | |||||
2. 用户同样不在 ``Trainer`` 之外初始化 ``ddp``,但是用户自己使用 ``python -m torch.distributed.launch`` 拉起来创建多个进程,这时我们仍旧 | |||||
会通过调用 ``dist.init_process_group`` 来初始化 ``ddp`` 的通信组;(情况 B) | |||||
.. code-block:: | |||||
trainer = Trainer( | |||||
... | |||||
driver='torch', | |||||
device=None | |||||
) | |||||
trainer.run() | |||||
通过运行 ``python -m torch.distributed.launch --nproc_per_node 2 train.py`` 启动; | |||||
3. 用户自己在外面初始化 ``DDP``,并且通过 ``python -m torch.distributed.launch`` 拉起,这时无论是多个进程的拉起和 ddp 的通信组的建立 | |||||
都由用户自己操作,我们只会在 ``driver.setup`` 的时候对 ``TorchDDPDriver`` 设置一些必要的属性值;(情况 C) | |||||
.. code-block:: | |||||
import torch.distributed as dist | |||||
from torch.nn.parallel import DistributedDataParallel | |||||
# 获取当前的进程信息; | |||||
... | |||||
# 初始化 ddp 不同进程间的通信组; | |||||
dist.init_process_group(...) | |||||
# 初始化模型使用 DistributedDataParallel 包裹; | |||||
model = Model() | |||||
model = DistributedDataParallel(model, ...) | |||||
# 注意此时仍旧不需要您主动地将 datalaoder 的 sampler 替换为 DistributedSampler; | |||||
trainer = Trainer( | |||||
... | |||||
driver='torch', | |||||
device=None | |||||
) | |||||
trainer.run() | |||||
通过运行 ``python -m torch.distributed.launch --nproc_per_node 2 train.py`` 启动; | |||||
注意多机的启动强制要求用户在每一台机器上使用 ``python -m torch.distributed.launch`` 启动;因此我们不会在 ``TorchDDPDriver`` 中保存 | |||||
任何当前有多少台机器的信息; | |||||
:param model: 传入给 ``Trainer`` 的 ``model`` 参数; | |||||
:param parallel_device: 用于分布式训练的 ``gpu`` 设备; | |||||
:param is_pull_by_torch_run: 标志当前的脚本的启动是否由 ``python -m torch.distributed.launch`` 启动的; | |||||
:param fp16: 是否开启 fp16 训练; | |||||
:param kwargs: 其余的一些用于设定 ddp 训练的参数; | |||||
""" | |||||
def __init__( | def __init__( | ||||
self, | self, | ||||
model, | model, | ||||
@@ -42,129 +246,7 @@ class TorchDDPDriver(TorchDriver): | |||||
fp16: bool = False, | fp16: bool = False, | ||||
**kwargs | **kwargs | ||||
): | ): | ||||
r""" | |||||
`TorchDDPDriver` 目前支持的三种启动方式: | |||||
1. 用户自己不进行 ddp 的任何操作,直接使用我们的 Trainer,这时是由我们自己使用 `open_subprocesses` 拉起多个进程, | |||||
然后 `TorchDDPDriver` 自己通过调用 `dist.init_process_group` 来初始化 ddp 的通信组;(情况 A) | |||||
2. 用户同样不在 Trainer 之外初始化 ddp,但是用户自己使用 python -m torch.distributed.launch 拉起来创建多个进程,这时我们仍旧 | |||||
会通过调用 `dist.init_process_group` 来初始化 ddp 的通信组;(情况 B) | |||||
3. 用户自己在外面初始化 DDP,并且通过 python -m torch.distributed.launch 拉起,这时无论是多个进程的拉起和 ddp 的通信组的建立 | |||||
都由用户自己操作,我们只会在 driver.setup 的时候对 `TorchDDPDriver` 设置一些必要的属性值;(情况 C) | |||||
注意多机的启动强制要求用户在每一台机器上使用 python -m torch.distributed.launch 启动;因此我们不会在 `TorchDDPDriver` 中保存 | |||||
任何当前有多少台机器的信息(num_nodes,不是 gpu 的数量); | |||||
Part 1:三种启动方式的具体分析: | |||||
(1)对于用户运行的脚本中,如果 `driver.setup` 只会被调用一次(意味着用户的启动脚本中只初始化了一个 trainer/evaluator)时, | |||||
`TorchDDPDriver` 在初始化以及 `setup` 函数中会做的事情分别如下所示: | |||||
-> 情况 A:这种情况下用户传入的 model 在一定是普通的 model(没有经 `DistributedDataParallel` 包裹的model), | |||||
因为 `DistributedDataParallel` 的使用一定要求 init_process_group 已经被调用用来建立当前的 ddp 通信组;但是这意味着如果 | |||||
用户需要使用 2 张以上的显卡,那么其必然需要使用 torch.distributed.launch 来启动,意味着就不是情况 A 了; | |||||
这时我们首先会调用 `TorchDDPDriver.open_subprocess` 函数来拉起多个进程,其中进程的数量等于用户传入给 trainer 的使用的 gpu | |||||
的数量(例如 `Trainer` 中的参数是 device=[0, 1, 6, 7],那么我们就会使用第 0、1、6、7 张 gpu 来拉起 4 个进程); | |||||
接着我们会调用 `dist.init_process_group` 来初始化各个进程之间的通信组; | |||||
这里需要注意拉起的新的进程会从前到后完整地运行一遍用户的启动脚本(例如 main.py),因此也都会运行这两个函数,但是需要注意只有进程 0 | |||||
才会去真正地运行 `TorchDDPDriver.open_subprocess`;进程 0 运行到 `dist.init_process_group`,pytorch 会阻塞进程 0 继续 | |||||
向前运行,直到其它进程也运行到这里; | |||||
最后我们会设置这个进程对应的 device,然后将模型迁移到对应的机器上,再使用 `DistributedDataParallel` 将模型包裹; | |||||
至此,ddp 的环境配置过程全部完成; | |||||
-> 情况 B:注意这种情况我们直接限定了用户是通过 torch.distributed.launch 拉起,并且没有自己建立 ddp 的通信组。这时在 | |||||
`TorchDDPDriver` 的初始化和 setup 函数的调用过程中,与情况 A 首要的不同就在于用户在 trainer 中输入的参数 device 不再有效, | |||||
这时每个进程所使用的 gpu 是我们直接通过 `torch.device("cuda:{local_rank}")` 来配置的;因此,如果用户想要实现使用特定 gpu | |||||
设备的目的,可以通过自己设置环境变量实现(例如 os.environ["CUDA_VISIBLE_DEVICE"] 来实现);剩下的操作和情况 A 类似; | |||||
-> 情况 C:注意这种情况我们限定了用户是通过 torch.distributed.launch 拉起,并且 ddp 的通信组也是由自己建立。这时基本上所有的 | |||||
与操作相关的操作都应当由用户自己完成,包括迁移模型到对应 gpu 上以及将模型用 `DistributedDataParallel` 包裹等。 | |||||
(2)如果 `driver.setup` 函数在脚本中会被调用两次及以上(意味着用户的启动脚本初始化了两个及以上的 trainer/evaluator)时: | |||||
注意这种情况下我们是会保证前后两个 trainer/evaluator 使用的 `TorchDDPDriver` 以及其初始化方式的一致性,换句话说,如果 trainer1 | |||||
检测到的启动方式是 '情况 A',那么我们会保证 trainer2 检测到的启动方式同样是 '情况A'(即使这需要一些额外的处理);因此这里我们主要讨论 | |||||
我们是通过怎样的操作来保证 trainer2/3/... 检测到的启动方式是和 trainer1 一致的;简单来说,我们是通过使用环境变量来标记每一种不同的 | |||||
启动方式来实现这一点的: | |||||
我们会使用 `FASTNLP_DISTRIBUTED_CHECK` 来标记 '情况 A',使用 `fastnlp_torch_launch_not_ddp` 来标记 '情况 B',意味着我们在 | |||||
使用 '情况 A' 来启动 `TorchDDPDriver` 时,我们会将 `FASTNLP_DISTRIBUTED_CHECK` 这一字符串注入到环境变量中,而 '情况 B' 时则 | |||||
会将 `fastnlp_torch_launch_not_ddp` 这一字符串注入到环境变量中。因此在 trainer2 的 `TorchDDPDriver` 的初始化和 setup 过程中, | |||||
如果检测到这些特殊的环境变量,我们就会将启动方式变更为其对应的启动方式,即使其它的参数特征属于另外的启动方式。 | |||||
Part 2:对应的代码细节: | |||||
1. 如何判断当前的各进程之间的通信组已经被建立(ddp 已经被初始化); | |||||
dist.is_initialized(); | |||||
2. 如何判断不同的进程是否是由 `python -m torch.distributed.launch` 拉起还是由我们的 `TorchDDPDriver.open_subprocess` | |||||
函数拉起; | |||||
我们会在用户脚本 `import fastNLP` 的时候检测当前的环境变量中是否有 'LOCAL_RANK'、'WORLD_SIZE' 以及没有 `FASTNLP_DISTRIBUTED_CHECK`, | |||||
如果满足条件,则我们会向环境变量中注入特殊的值 'FASTNLP_BACKEND_LAUNCH' 来标记用户是否使用了 `python -m torch.distributed.launch` | |||||
来拉起多个进程; | |||||
3. 整体的处理判断流程: | |||||
___________________________________ | |||||
|进入 TorchDDPDriver 的 __init__ 函数| | |||||
——————————————————————————————————— | |||||
↓ | |||||
___________________________________________________ | |||||
| 判断不同的进程是否是由 torch.distributed.launch 拉起 | | |||||
|(或者我们自己的 open_subprocess 函数拉起) | --------------> | |||||
——————————————————————————————————————————————————— | | |||||
↓ 是由 torch.distributed.launch 拉起 | 我们自己的 open_subprocess 函数拉起多个进程 | |||||
___________________________ | | |||||
←←←←← | 检测用户是否自己初始化了 ddp | | | |||||
↓ ——————————————————————————— ↓ | |||||
↓ ↓ 是 ________ | |||||
↓ ______ | 情况 A | | |||||
↓ 否 |情况 C| ————————— | |||||
↓ ——————— | |||||
↓ | |||||
↓ ______ | |||||
↓ -----------> |情况 B| | |||||
——————— | |||||
4. 为了完成全部的建立 ddp 所需要的操作,三种情况都需要做的事情,以及每件事情的职责归属: | |||||
情况 A | 情况 B | 情况 C | |||||
________________________________________________________________________________________________________ | |||||
配置 ddp 所 | TorchDDPDriver.open_subprocess | torch.distributed.launch| torch.distributed.launch | |||||
需要的环境变量 | | | | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
开启多个进程 | TorchDDPDriver.open_subprocess | torch.distributed.launch| torch.distributed.launch | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
调用 dist. | | | | |||||
init_process\ | TorchDDPDriver.setup | TorchDDPDriver.setup | 用户自己调用 | |||||
_group 函数 | | | | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
设置 TorchDDPDriver | | | | |||||
的 world_size 和 | TorchDDPDriver.setup | TorchDDPDriver.setup | TorchDDPDriver.setup | |||||
global_rank 属性 | | | | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
Part 3:其它的处理细节: | |||||
1. 环境变量; | |||||
fastNLP 的 `TorchDDPDriver` 运行时所需要的环境变量分为两种,一种是 torch 的 ddp 运行所需要的环境变量;另一种是 fastNLP 自己 | |||||
的环境变量。前者的配置情况如上表所示;而后者中的大多数环境变量则是在用户 import fastNLP 时就设置好了; | |||||
2. parallel_device, model_device 和 data_device 的关系; | |||||
parallel_device 为 `TorchDDPDriver` 的参数,model_device 和 data_device 都为 driver 的属性; | |||||
其中 data_device 仅当情况 C 时由用户自己指定;如果其不为 None,那么在模型 forward 的时候,我们就会将数据迁移到 data_device 上; | |||||
model_device 永远都为单独的一个 torch.device; | |||||
情况 A | 情况 B | 情况 C | |||||
________________________________________________________________________________________________________ | |||||
parallel_device | 由用户传入trainer的参数 | 为 torch.device( | 为 torch.device( | |||||
| device 决定,必须是一个list, | "cuda:{local_rank}") | "cuda:{local_rank}") | |||||
| 其中每一个对象都是 torch.device | | | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
model_device | parallel_device[local_rank] | parallel_device | None | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
data_device | model_device | model_device | 由用户传入 trainer 的参数 | |||||
| | | data_device 决定 | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
3. _DDPWrappingModel 的作用; | |||||
因为我们即需要调用模型的 `train_step`、`evaluate_step`、`test_step` 方法,又需要通过 `DistributedDataParallel` 的 | |||||
forward 函数来帮助我们同步各个设备上的梯度,因此我们需要先将模型单独包裹一层,然后在 forward 的时候,其先经过 `DistributedDataParallel` | |||||
的 forward 方法,然后再经过 `_DDPWrappingModel` 的 forward 方法,我们会在该 forward 函数中进行判断,确定调用的是模型自己的 | |||||
forward 函数,还是 `train_step`、`evaluate_step`、`test_step` 方法。 | |||||
4. 当某一个进程出现 exception 后,`TorchDDPDriver` 的处理; | |||||
不管是什么情况,`TorchDDPDriver` 在 `setup` 函数的最后,都会将所有进程的 pid 主动记录下来,这样当一个进程出现 exception 后, | |||||
driver 的 on_exception 函数就会被 trainer 调用,其会调用 os.kill 指令将其它进程 kill 掉; | |||||
""" | |||||
# 在加入很多东西后,需要注意这里调用 super 函数的位置; | # 在加入很多东西后,需要注意这里调用 super 函数的位置; | ||||
super(TorchDDPDriver, self).__init__(model, fp16=fp16, **kwargs) | super(TorchDDPDriver, self).__init__(model, fp16=fp16, **kwargs) | ||||
@@ -176,8 +258,9 @@ class TorchDDPDriver(TorchDriver): | |||||
self.is_pull_by_torch_run = is_pull_by_torch_run | self.is_pull_by_torch_run = is_pull_by_torch_run | ||||
self.parallel_device = parallel_device | self.parallel_device = parallel_device | ||||
if not is_pull_by_torch_run and parallel_device is None: | if not is_pull_by_torch_run and parallel_device is None: | ||||
raise ValueError("Parameter `parallel_device` can not be None when using `TorchDDPDriver`. This error is caused " | |||||
"when your value of parameter `device` is `None` in your `Trainer` instance.") | |||||
raise ValueError( | |||||
"Parameter `parallel_device` can not be None when using `TorchDDPDriver`. This error is caused " | |||||
"when your value of parameter `device` is `None` in your `Trainer` instance.") | |||||
# 注意我们在 initialize_torch_driver 中的逻辑就是如果是 is_pull_by_torch_run,那么我们就直接把 parallel_device 置为当前进程的gpu; | # 注意我们在 initialize_torch_driver 中的逻辑就是如果是 is_pull_by_torch_run,那么我们就直接把 parallel_device 置为当前进程的gpu; | ||||
if is_pull_by_torch_run: | if is_pull_by_torch_run: | ||||
@@ -233,10 +316,16 @@ class TorchDDPDriver(TorchDriver): | |||||
os.makedirs(name=self.output_from_new_proc, exist_ok=True) | os.makedirs(name=self.output_from_new_proc, exist_ok=True) | ||||
self.output_from_new_proc = os.path.abspath(self.output_from_new_proc) | self.output_from_new_proc = os.path.abspath(self.output_from_new_proc) | ||||
self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; | |||||
self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; | |||||
self._has_ddpwrapped = False # 判断传入的模型是否经过 _has_ddpwrapped 包裹; | self._has_ddpwrapped = False # 判断传入的模型是否经过 _has_ddpwrapped 包裹; | ||||
def setup(self): | def setup(self): | ||||
r""" | |||||
准备分布式环境,该函数主要做以下两件事情: | |||||
1. 开启多进程,每个 gpu 设备对应单独的一个进程; | |||||
2. 每个进程将模型迁移到自己对应的 ``gpu`` 设备上;然后使用 ``DistributedDataParallel`` 包裹模型; | |||||
""" | |||||
if self._has_setup: | if self._has_setup: | ||||
return | return | ||||
self._has_setup = True | self._has_setup = True | ||||
@@ -280,9 +369,10 @@ class TorchDDPDriver(TorchDriver): | |||||
# 使用的(即之后的)TorchDDPDriver 的设置和第一个 TorchDDPDriver 是完全一样的; | # 使用的(即之后的)TorchDDPDriver 的设置和第一个 TorchDDPDriver 是完全一样的; | ||||
pre_num_processes = int(os.environ[FASTNLP_DISTRIBUTED_CHECK]) | pre_num_processes = int(os.environ[FASTNLP_DISTRIBUTED_CHECK]) | ||||
if pre_num_processes != len(self.parallel_device): | if pre_num_processes != len(self.parallel_device): | ||||
raise RuntimeError("Notice you are using `TorchDDPDriver` after one instantiated `TorchDDPDriver`, it is not" | |||||
"allowed that your second `TorchDDPDriver` has a new setting of parameters " | |||||
"`num_nodes` and `num_processes`.") | |||||
raise RuntimeError( | |||||
"Notice you are using `TorchDDPDriver` after one instantiated `TorchDDPDriver`, it is not" | |||||
"allowed that your second `TorchDDPDriver` has a new setting of parameters " | |||||
"`num_nodes` and `num_processes`.") | |||||
self.world_size = dist.get_world_size() | self.world_size = dist.get_world_size() | ||||
self.global_rank = dist.get_rank() | self.global_rank = dist.get_rank() | ||||
@@ -302,7 +392,7 @@ class TorchDDPDriver(TorchDriver): | |||||
local_world_size = local_world_size.tolist() + 1 | local_world_size = local_world_size.tolist() + 1 | ||||
node_rank = self.global_rank // local_world_size | node_rank = self.global_rank // local_world_size | ||||
self._pids = self._pids[node_rank*local_world_size: (node_rank+1)*local_world_size] | |||||
self._pids = self._pids[node_rank * local_world_size: (node_rank + 1) * local_world_size] | |||||
self._pids = self.tensor_to_numeric(self._pids) | self._pids = self.tensor_to_numeric(self._pids) | ||||
def configure_ddp(self): | def configure_ddp(self): | ||||
@@ -423,9 +513,10 @@ class TorchDDPDriver(TorchDriver): | |||||
return self.model, model.forward | return self.model, model.forward | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]]=None, | |||||
def set_dist_repro_dataloader(self, dataloader, | |||||
dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]] = None, | |||||
reproducible: bool = False): | reproducible: bool = False): | ||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用; | |||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load_checkpoint 函数调用; | |||||
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; | # 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; | ||||
if isinstance(dist, ReproducibleBatchSampler): | if isinstance(dist, ReproducibleBatchSampler): | ||||
dist.set_distributed( | dist.set_distributed( | ||||
@@ -505,16 +596,26 @@ class TorchDDPDriver(TorchDriver): | |||||
batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False) | batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False) | ||||
return replace_batch_sampler(dataloader, batch_sampler) | return replace_batch_sampler(dataloader, batch_sampler) | ||||
else: | else: | ||||
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | |||||
raise ValueError( | |||||
"Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | |||||
def is_global_zero(self): | def is_global_zero(self): | ||||
r""" | |||||
:return: 返回当前的进程是否在全局上是进程 0 ; | |||||
""" | |||||
return self.global_rank == 0 | return self.global_rank == 0 | ||||
def get_model_no_sync_context(self): | def get_model_no_sync_context(self): | ||||
r""" | |||||
:return: 返回一个 ``context`` 上下文环境,用于关闭各个进程之间的同步; | |||||
""" | |||||
# 注意此时的 model 是 "DistributedDataParallel" 对象; | # 注意此时的 model 是 "DistributedDataParallel" 对象; | ||||
return self.model.no_sync | return self.model.no_sync | ||||
def unwrap_model(self): | def unwrap_model(self): | ||||
r""" | |||||
:return: 返回没有经过 ``DistributedDataParallel`` 包裹的原始模型; | |||||
""" | |||||
_module = self.model.module | _module = self.model.module | ||||
if isinstance(_module, _DDPWrappingModel): | if isinstance(_module, _DDPWrappingModel): | ||||
return _module.model | return _module.model | ||||
@@ -522,17 +623,26 @@ class TorchDDPDriver(TorchDriver): | |||||
return _module | return _module | ||||
def get_local_rank(self) -> int: | def get_local_rank(self) -> int: | ||||
r""" | |||||
:return: 返回当前进程局部的进程编号; | |||||
""" | |||||
return self.local_rank | return self.local_rank | ||||
def barrier(self): | def barrier(self): | ||||
r""" | |||||
通过使用该函数来使得各个进程之间同步操作; | |||||
""" | |||||
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 | if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 | ||||
torch.distributed.barrier(async_op=False) | torch.distributed.barrier(async_op=False) | ||||
def is_distributed(self): | def is_distributed(self): | ||||
r""" | |||||
:return: 返回当前使用的 driver 是否是分布式的 driver,对于 ``TorchDDPDriver`` 来说,该函数一定返回 ``True``; | |||||
""" | |||||
return True | return True | ||||
def broadcast_object(self, obj, src:int=0, group=None, **kwargs): | |||||
""" | |||||
def broadcast_object(self, obj, src: int = 0, group=None, **kwargs): | |||||
r""" | |||||
从 src 端将 obj 对象(可能是 tensor ,可能是 object )发送到 dst 处。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行 | 从 src 端将 obj 对象(可能是 tensor ,可能是 object )发送到 dst 处。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行 | ||||
传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。 | 传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。 | ||||
@@ -540,7 +650,6 @@ class TorchDDPDriver(TorchDriver): | |||||
:param int src: source 的 global rank 。 | :param int src: source 的 global rank 。 | ||||
:param int dst: target 的 global rank,可以是多个目标 rank | :param int dst: target 的 global rank,可以是多个目标 rank | ||||
:param group: 所属的 group | :param group: 所属的 group | ||||
:param kwargs: | |||||
:return: 如果当前不是分布式 driver 直接返回输入的 obj 。如果当前 rank 是接收端(其 global rank 包含在了 dst 中),则返回 | :return: 如果当前不是分布式 driver 直接返回输入的 obj 。如果当前 rank 是接收端(其 global rank 包含在了 dst 中),则返回 | ||||
接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 | 接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 | ||||
""" | """ | ||||
@@ -549,7 +658,7 @@ class TorchDDPDriver(TorchDriver): | |||||
return fastnlp_torch_broadcast_object(obj, src, device=self.data_device, group=group) | return fastnlp_torch_broadcast_object(obj, src, device=self.data_device, group=group) | ||||
def all_gather(self, obj, group) -> List: | def all_gather(self, obj, group) -> List: | ||||
""" | |||||
r""" | |||||
将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过 | 将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过 | ||||
pickle 进行序列化,接收到之后再反序列化。 | pickle 进行序列化,接收到之后再反序列化。 | ||||
@@ -578,10 +687,9 @@ class TorchDDPDriver(TorchDriver): | |||||
def find_free_network_port() -> str: | def find_free_network_port() -> str: | ||||
"""Finds a free port on localhost. | |||||
It is useful in single-node training when we don't want to connect to a real master node but have to set the | |||||
`MASTER_PORT` environment variable. | |||||
""" | |||||
在 localhost 上找到一个空闲端口; | |||||
当我们不想连接到真正的主节点但必须设置“MASTER_PORT”环境变量时在单节点训练中很有用; | |||||
""" | """ | ||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | ||||
s.bind(("", 0)) | s.bind(("", 0)) | ||||
@@ -145,6 +145,27 @@ def _tensor_to_object(tensor, tensor_size): | |||||
def send_recv_object(obj, src, cur_rank, device, group=None, tag=0): | def send_recv_object(obj, src, cur_rank, device, group=None, tag=0): | ||||
r""" | |||||
pytorch 中的单点对多点的分发函数; | |||||
例如将进程 0 上的对象 object 分发到其它进程上; | |||||
Example:: | |||||
cur_rank = int(os.environ.get('LOCAL_RANK', 0)) | |||||
# 拿到 local_device | |||||
send_recv_object(object, 0, cur_rank, local_device) | |||||
:param obj: 一个可以序列化的 python 对象; | |||||
:param src: 从哪一个 rank 上发送到其它 rank; | |||||
:param cur_rank: 当前的进程的 rank 序号; | |||||
:param device: 当前的进程所在的设备; | |||||
:param group: 通信组,默认为 None; | |||||
:param tag: 将发送与远程接收匹配的标记; | |||||
:return: | |||||
""" | |||||
# src rank send to all other ranks | # src rank send to all other ranks | ||||
size = torch.LongTensor([0]).to(device) | size = torch.LongTensor([0]).to(device) | ||||
@@ -26,13 +26,13 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi | |||||
# world_size 和 rank | # world_size 和 rank | ||||
if FASTNLP_BACKEND_LAUNCH in os.environ: | if FASTNLP_BACKEND_LAUNCH in os.environ: | ||||
if device is not None: | if device is not None: | ||||
logger.warning_once("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull " | |||||
logger.rank_zero_warning("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull " | |||||
"up your script. And we will directly get the local device via " | "up your script. And we will directly get the local device via " | ||||
"`os.environ['LOCAL_RANK']`.") | |||||
"`os.environ['LOCAL_RANK']`.", once=True) | |||||
return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs) | return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs) | ||||
if driver not in {"torch", "fairscale"}: | if driver not in {"torch", "fairscale"}: | ||||
raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'torch_ddp', 'fairscale'].") | |||||
raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale'].") | |||||
_could_use_device_num = torch.cuda.device_count() | _could_use_device_num = torch.cuda.device_count() | ||||
if isinstance(device, str): | if isinstance(device, str): | ||||
@@ -43,6 +43,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi | |||||
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | ||||
device = [torch.device(f"cuda:{w}") for w in range(_could_use_device_num)] | device = [torch.device(f"cuda:{w}") for w in range(_could_use_device_num)] | ||||
elif device >= _could_use_device_num: | elif device >= _could_use_device_num: | ||||
print(device, _could_use_device_num) | |||||
raise ValueError("The gpu device that parameter `device` specifies is not existed.") | raise ValueError("The gpu device that parameter `device` specifies is not existed.") | ||||
else: | else: | ||||
device = torch.device(f"cuda:{device}") | device = torch.device(f"cuda:{device}") | ||||
@@ -1,11 +1,13 @@ | |||||
import os | import os | ||||
from typing import Dict, Union, Callable, Tuple, Optional | from typing import Dict, Union, Callable, Tuple, Optional | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
from torch.nn import DataParallel | from torch.nn import DataParallel | ||||
from torch.nn.parallel import DistributedDataParallel | from torch.nn.parallel import DistributedDataParallel | ||||
from torch.utils.data import RandomSampler as TorchRandomSampler | from torch.utils.data import RandomSampler as TorchRandomSampler | ||||
from torch.utils.data import SequentialSampler as TorchSequentialSampler | |||||
__all__ = [ | __all__ = [ | ||||
'TorchSingleDriver' | 'TorchSingleDriver' | ||||
@@ -15,15 +17,25 @@ from .torch_driver import TorchDriver | |||||
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | ||||
from fastNLP.core.utils import auto_param_call | from fastNLP.core.utils import auto_param_call | ||||
from fastNLP.core.utils.utils import _get_fun_msg | from fastNLP.core.utils.utils import _get_fun_msg | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, ReproduceBatchSampler | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \ | |||||
ReproduceBatchSampler | |||||
from fastNLP.core.samplers import RandomSampler | from fastNLP.core.samplers import RandomSampler | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
class TorchSingleDriver(TorchDriver): | class TorchSingleDriver(TorchDriver): | ||||
r""" | r""" | ||||
用于 cpu 和 单卡 gpu 运算; | |||||
``TorchSingleDriver`` 是用于 cpu 和 单卡 gpu 运算的 ``driver``; | |||||
.. note:: | |||||
如果您希望使用 ``DataParallel`` 来训练您的模型,您应当自己在 ``Trainer`` 初始化之前初始化好 ``DataParallel``,然后将其传入 ``Trainer`` 中; | |||||
:param model: 传入给 ``Trainer`` 的 ``model`` 参数; | |||||
:param device: torch.device,当前进程所使用的设备; | |||||
:param fp16: 是否开启 fp16; | |||||
""" | """ | ||||
def __init__(self, model, device: "torch.device", fp16: bool = False, **kwargs): | def __init__(self, model, device: "torch.device", fp16: bool = False, **kwargs): | ||||
if isinstance(model, DistributedDataParallel): | if isinstance(model, DistributedDataParallel): | ||||
raise ValueError("`DistributedDataParallel` is not supported in `TorchSingleDriver`") | raise ValueError("`DistributedDataParallel` is not supported in `TorchSingleDriver`") | ||||
@@ -51,6 +63,9 @@ class TorchSingleDriver(TorchDriver): | |||||
self.world_size = 1 | self.world_size = 1 | ||||
def setup(self): | def setup(self): | ||||
r""" | |||||
将模型迁移到相应的设备上; | |||||
""" | |||||
if self.model_device is not None: | if self.model_device is not None: | ||||
self.model.to(self.model_device) | self.model.to(self.model_device) | ||||
@@ -88,10 +103,11 @@ class TorchSingleDriver(TorchDriver): | |||||
else: | else: | ||||
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") | raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, | |||||
def set_dist_repro_dataloader(self, dataloader, | |||||
dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None, | |||||
reproducible: bool = False): | reproducible: bool = False): | ||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load_checkpoint 函数调用; | |||||
if isinstance(dist, ReproducibleBatchSampler): | if isinstance(dist, ReproducibleBatchSampler): | ||||
return replace_batch_sampler(dataloader, dist) | return replace_batch_sampler(dataloader, dist) | ||||
elif isinstance(dist, ReproducibleSampler): | elif isinstance(dist, ReproducibleSampler): | ||||
@@ -108,21 +124,31 @@ class TorchSingleDriver(TorchDriver): | |||||
if reproducible: | if reproducible: | ||||
if isinstance(args.sampler, TorchRandomSampler): | if isinstance(args.sampler, TorchRandomSampler): | ||||
# 如果本来就是随机的,直接替换掉吧。 | |||||
sampler = RandomSampler(args.sampler.data_source) | |||||
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") | |||||
if getattr(args.sampler, '_num_samples', None) is None \ | |||||
and getattr(args.sampler, 'replacements', False) is False \ | |||||
and getattr(args.sampler, 'generator', None) is None: | |||||
# 如果本来就是随机的,并且没有定制,直接替换掉吧。 | |||||
sampler = RandomSampler(args.sampler.data_source, shuffle=True) | |||||
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") | |||||
return replace_sampler(dataloader, sampler) | |||||
elif isinstance(args.sampler, TorchSequentialSampler): | |||||
# 需要替换为不要 shuffle 的。 | |||||
sampler = RandomSampler(args.sampler.data_source, shuffle=False) | |||||
logger.debug("Replace torch SequentialSampler into fastNLP RandomSampler.") | |||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
else: | |||||
batch_sampler = ReproduceBatchSampler( | |||||
batch_sampler=args.batch_sampler, | |||||
batch_size=args.batch_size, | |||||
drop_last=args.drop_last | |||||
) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
batch_sampler = ReproduceBatchSampler( | |||||
batch_sampler=args.batch_sampler, | |||||
batch_size=args.batch_size, | |||||
drop_last=args.drop_last | |||||
) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
else: | else: | ||||
return dataloader | return dataloader | ||||
def unwrap_model(self): | def unwrap_model(self): | ||||
r""" | |||||
:return: 返回原本的模型,例如没有被 ``DataParallel`` 包裹; | |||||
""" | |||||
if isinstance(self.model, torch.nn.DataParallel) or \ | if isinstance(self.model, torch.nn.DataParallel) or \ | ||||
isinstance(self.model, torch.nn.parallel.DistributedDataParallel): | isinstance(self.model, torch.nn.parallel.DistributedDataParallel): | ||||
return self.model.module | return self.model.module | ||||
@@ -131,16 +157,13 @@ class TorchSingleDriver(TorchDriver): | |||||
@property | @property | ||||
def data_device(self): | def data_device(self): | ||||
""" | |||||
单卡模式不支持 data_device; | |||||
r""" | |||||
注意单卡模式下使用 ``driver.data_device`` 等价于使用 ``driver.model_device``; | |||||
""" | """ | ||||
return self.model_device | return self.model_device | ||||
def is_distributed(self): | def is_distributed(self): | ||||
r""" | |||||
:return: 返回当前使用的 driver 是否是分布式的 driver,对于 ``TorchSingleDriver`` 来说直接返回 ``False``; | |||||
""" | |||||
return False | return False | ||||
@@ -36,7 +36,17 @@ from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, | |||||
class TorchDriver(Driver): | class TorchDriver(Driver): | ||||
r""" | r""" | ||||
专属于 pytorch 的 driver;因为我们会在同一个 Trainer 框架下提供 jittor、paddle 等训练框架的支持; | |||||
专属于 ``pytorch`` 的 ``driver``,是 ``TorchSingleDriver`` 和 ``TorchDDPDriver`` 的父类; | |||||
.. warning:: | |||||
您不应当直接初始化该类,然后传入给 ``Trainer``,换句话说,您应当使用该类的子类 ``TorchSingleDriver`` 和 ``TorchDDPDriver``,而不是 | |||||
该类本身; | |||||
.. note:: | |||||
您可以在使用 ``TorchSingleDriver`` 和 ``TorchDDPDriver`` 时使用 ``TorchDriver`` 提供的接口; | |||||
""" | """ | ||||
def __init__(self, model, fp16: Optional[bool] = False, **kwargs): | def __init__(self, model, fp16: Optional[bool] = False, **kwargs): | ||||
super(TorchDriver, self).__init__(model) | super(TorchDriver, self).__init__(model) | ||||
@@ -111,7 +121,15 @@ class TorchDriver(Driver): | |||||
f"not {type(each_optimizer)}.") | f"not {type(each_optimizer)}.") | ||||
@staticmethod | @staticmethod | ||||
def tensor_to_numeric(tensor, reduce=None): | |||||
def tensor_to_numeric(tensor, reduce: str = None): | |||||
r""" | |||||
将 ``torch.Tensor`` 转换成 python 中的数值类型; | |||||
:param tensor: ``torch.Tensor``; | |||||
:param reduce: 当 tensor 是一个多数值的张量时,应当使用何种归一化操作来转换成单一数值,应当为以下类型之一:``['max', 'min', 'sum', 'mean']``; | |||||
:return: 返回一个单一数值,其数值类型是 python 中的基本的数值类型,例如 ``int,float`` 等; | |||||
""" | |||||
if tensor is None: | if tensor is None: | ||||
return None | return None | ||||
@@ -129,6 +147,10 @@ class TorchDriver(Driver): | |||||
) | ) | ||||
def set_model_mode(self, mode: str): | def set_model_mode(self, mode: str): | ||||
r""" | |||||
设置模型的状态是 ``train`` 还是 ``eval``; | |||||
:param mode: ``train`` 或者 ``eval``; | |||||
""" | |||||
assert mode in {"train", "eval"} | assert mode in {"train", "eval"} | ||||
getattr(self.model, mode)() | getattr(self.model, mode)() | ||||
@@ -179,7 +201,7 @@ class TorchDriver(Driver): | |||||
model.load_state_dict(res.state_dict()) | model.load_state_dict(res.state_dict()) | ||||
@rank_zero_call | @rank_zero_call | ||||
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||||
def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||||
# 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变 | # 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变 | ||||
# trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; | # trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; | ||||
@@ -253,7 +275,7 @@ class TorchDriver(Driver): | |||||
states["optimizers_state_dict"] = optimizers_state_dict | states["optimizers_state_dict"] = optimizers_state_dict | ||||
torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | ||||
def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||||
def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||||
states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) | states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) | ||||
# 1. 加载 optimizers 的状态; | # 1. 加载 optimizers 的状态; | ||||
@@ -297,7 +319,7 @@ class TorchDriver(Driver): | |||||
sampler = RandomSampler(dataloader_args.sampler.data_source) | sampler = RandomSampler(dataloader_args.sampler.data_source) | ||||
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") | logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") | ||||
elif self.is_distributed(): | elif self.is_distributed(): | ||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " | |||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our" | |||||
"`ReproducibleSampler`.") | "`ReproducibleSampler`.") | ||||
else: | else: | ||||
sampler = ReproduceBatchSampler( | sampler = ReproduceBatchSampler( | ||||
@@ -326,14 +348,26 @@ class TorchDriver(Driver): | |||||
return states | return states | ||||
def get_evaluate_context(self): | def get_evaluate_context(self): | ||||
r""" | |||||
:return: 返回 ``torch.no_grad`` 这个 context; | |||||
""" | |||||
return torch.no_grad | return torch.no_grad | ||||
@staticmethod | @staticmethod | ||||
def move_model_to_device(model: "torch.nn.Module", device: "torch.device"): | def move_model_to_device(model: "torch.nn.Module", device: "torch.device"): | ||||
r""" | |||||
将模型迁移到对应的设备上; | |||||
""" | |||||
if device is not None: | if device is not None: | ||||
model.to(device) | model.to(device) | ||||
def move_data_to_device(self, batch: "torch.Tensor"): | |||||
def move_data_to_device(self, batch): | |||||
""" | |||||
将一个 batch 的数据迁移到对应的设备上; | |||||
:param batch: 一个 batch 的数据,可以是 ``list、dict`` 等; | |||||
:return: | |||||
""" | |||||
return torch_move_data_to_device(batch, self.data_device, self.non_blocking) | return torch_move_data_to_device(batch, self.data_device, self.non_blocking) | ||||
@staticmethod | @staticmethod | ||||
@@ -174,7 +174,7 @@ def _build_fp16_env(dummy=False): | |||||
def replace_sampler(dataloader: "DataLoader", sampler): | def replace_sampler(dataloader: "DataLoader", sampler): | ||||
""" | |||||
r""" | |||||
替换 sampler (初始化一个新的 dataloader 的逻辑在于): | 替换 sampler (初始化一个新的 dataloader 的逻辑在于): | ||||
用户可能继承了 dataloader,定制了自己的 dataloader 类,这也是我们为什么先 `inspect.signature(dataloader)` 而不是直接 | 用户可能继承了 dataloader,定制了自己的 dataloader 类,这也是我们为什么先 `inspect.signature(dataloader)` 而不是直接 | ||||
@@ -259,7 +259,7 @@ def replace_sampler(dataloader: "DataLoader", sampler): | |||||
def _dataloader_init_kwargs_resolve_sampler( | def _dataloader_init_kwargs_resolve_sampler( | ||||
dataloader: "DataLoader", sampler: Optional["Sampler"] | dataloader: "DataLoader", sampler: Optional["Sampler"] | ||||
) -> Dict[str, Any]: | ) -> Dict[str, Any]: | ||||
""" | |||||
r""" | |||||
此函数用于处理与 DataLoader 关联的采样器、batch_sampler 参数重新实例化; | 此函数用于处理与 DataLoader 关联的采样器、batch_sampler 参数重新实例化; | ||||
""" | """ | ||||
batch_sampler = getattr(dataloader, "batch_sampler") | batch_sampler = getattr(dataloader, "batch_sampler") | ||||
@@ -279,15 +279,8 @@ def _dataloader_init_kwargs_resolve_sampler( | |||||
def replace_batch_sampler(dataloader, new_batch_sampler): | def replace_batch_sampler(dataloader, new_batch_sampler): | ||||
"""Helper function to replace current batch sampler of the dataloader by a new batch sampler. Function returns new | |||||
dataloader with new batch sampler. | |||||
Args: | |||||
dataloader: input dataloader | |||||
new_batch_sampler: new batch sampler to use | |||||
Returns: | |||||
DataLoader | |||||
r""" | |||||
替换一个 dataloader 的 batch_sampler; | |||||
""" | """ | ||||
params_keys = [k for k in dataloader.__dict__.keys() if not k.startswith("_")] | params_keys = [k for k in dataloader.__dict__.keys() if not k.startswith("_")] | ||||
for k in ["batch_size", "sampler", "drop_last", "batch_sampler", "dataset_kind"]: | for k in ["batch_size", "sampler", "drop_last", "batch_sampler", "dataset_kind"]: | ||||
@@ -296,12 +289,16 @@ def replace_batch_sampler(dataloader, new_batch_sampler): | |||||
params = {k: getattr(dataloader, k) for k in params_keys} | params = {k: getattr(dataloader, k) for k in params_keys} | ||||
params["batch_sampler"] = new_batch_sampler | params["batch_sampler"] = new_batch_sampler | ||||
return type(dataloader)(**params) | return type(dataloader)(**params) | ||||
# TODO 这里是否可以auto_param_call一下 | |||||
# return auto_param_call(type(dataloader), params, {'self': type(dataloader).__new__()}, | |||||
# signature_fn=type(dataloader).__init__) | |||||
def optimizer_state_to_device(state, device): | def optimizer_state_to_device(state, device): | ||||
r""" | |||||
将一个 ``optimizer`` 的 ``state_dict`` 迁移到对应的设备; | |||||
:param state: ``optimzier.state_dict()``; | |||||
:param device: 要迁移到的目的设备; | |||||
:return: 返回迁移后的新的 state_dict; | |||||
""" | |||||
new_state = {} | new_state = {} | ||||
for name, param in state.items(): | for name, param in state.items(): | ||||
if isinstance(param, dict): | if isinstance(param, dict): | ||||
@@ -3,7 +3,7 @@ import subprocess | |||||
def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy:dict, rank:int=None): | def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy:dict, rank:int=None): | ||||
""" | |||||
r""" | |||||
使用 command 通过 subprocess.Popen 开启新的进程。 | 使用 command 通过 subprocess.Popen 开启新的进程。 | ||||
:param output_from_new_proc: 可选 ["ignore", "all", "only_error"],以上三个为特殊关键字,分别表示完全忽略拉起进程的打印输出, | :param output_from_new_proc: 可选 ["ignore", "all", "only_error"],以上三个为特殊关键字,分别表示完全忽略拉起进程的打印输出, | ||||
@@ -11,8 +11,8 @@ def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy: | |||||
两个文件,名称分别为 {rank}_std.log, {rank}_err.log 。原有的文件会被直接覆盖。 | 两个文件,名称分别为 {rank}_std.log, {rank}_err.log 。原有的文件会被直接覆盖。 | ||||
:param command: List[str] 启动的命令 | :param command: List[str] 启动的命令 | ||||
:param env_copy: 需要注入的环境变量。 | :param env_copy: 需要注入的环境变量。 | ||||
:param rank: | |||||
:return: | |||||
:param rank: global_rank; | |||||
:return: 返回使用 ``subprocess.Popen`` 打开的进程; | |||||
""" | """ | ||||
if output_from_new_proc == "all": | if output_from_new_proc == "all": | ||||
proc = subprocess.Popen(command, env=env_copy) | proc = subprocess.Popen(command, env=env_copy) | ||||
@@ -1,11 +1,12 @@ | |||||
__all__ = [ | __all__ = [ | ||||
"Metric", | "Metric", | ||||
"Accuracy", | "Accuracy", | ||||
"TransformersAccuracy", | |||||
'SpanFPreRecMetric', | 'SpanFPreRecMetric', | ||||
'ClassifyFPreRecMetric', | 'ClassifyFPreRecMetric', | ||||
] | ] | ||||
from .metric import Metric | from .metric import Metric | ||||
from .accuracy import Accuracy | |||||
from .accuracy import Accuracy, TransformersAccuracy | |||||
from .span_f1_pre_rec_metric import SpanFPreRecMetric | from .span_f1_pre_rec_metric import SpanFPreRecMetric | ||||
from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric | from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric |
@@ -1,5 +1,6 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'Accuracy' | |||||
'Accuracy', | |||||
"TransformersAccuracy" | |||||
] | ] | ||||
from typing import Union | from typing import Union | ||||
@@ -17,9 +18,9 @@ class Accuracy(Metric): | |||||
""" | """ | ||||
计算 准确率 的 metric 。 | 计算 准确率 的 metric 。 | ||||
:param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() | |||||
:param backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() | |||||
函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 | 函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 | ||||
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到 metric, | |||||
:param aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到 metric, | |||||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 | 当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 | ||||
""" | """ | ||||
super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | ||||
@@ -39,11 +40,11 @@ class Accuracy(Metric): | |||||
r""" | r""" | ||||
update 函数将针对一个批次的预测结果做评价指标的累计 | update 函数将针对一个批次的预测结果做评价指标的累计 | ||||
:param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | |||||
:param pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | |||||
torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) | torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) | ||||
:param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), | |||||
:param target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), | |||||
torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len]) | torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len]) | ||||
:param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, None, torch.Size([B]), 或者torch.Size([B]). | |||||
:param seq_len: 序列长度标记, 标记的形状可以是None, None, torch.Size([B]), 或者torch.Size([B]). | |||||
如果mask也被传进来的话seq_len会被忽略. | 如果mask也被传进来的话seq_len会被忽略. | ||||
""" | """ | ||||
# 为了兼容不同框架,我们将输入变量全部转为numpy类型来进行计算。 | # 为了兼容不同框架,我们将输入变量全部转为numpy类型来进行计算。 | ||||
@@ -79,3 +80,20 @@ class Accuracy(Metric): | |||||
else: | else: | ||||
self.total += np.prod(list(pred.shape)).item() | self.total += np.prod(list(pred.shape)).item() | ||||
self.correct += (target == pred).sum().item() | self.correct += (target == pred).sum().item() | ||||
class TransformersAccuracy(Accuracy): | |||||
""" | |||||
适配 transformers 中相关模型的 Accuracy metric 。 | |||||
""" | |||||
def update(self, logits, labels, attention_mask=None): | |||||
r""" | |||||
update 函数将针对一个批次的预测结果做评价指标的累计 | |||||
:param logits: 形状为 ``[B, n_classes]`` 或 ``[B, max_len, n_classes]`` 。 | |||||
:param labels: 形状为 ``[B, ]`` 或 ``[B, max_len]`` | |||||
:param attention_mask: 序列长度标记。 | |||||
""" | |||||
seq_len = attention_mask.sum(dim=-1) | |||||
super().update(pred=logits, target=labels, seq_len=seq_len) |
@@ -14,14 +14,16 @@ from fastNLP.core.metrics.element import Element | |||||
class Metric: | class Metric: | ||||
""" | |||||
fastNLP 中 Metric 的基类,自定义 Metric 时,请继承该对象。使用该对象,将有助于减少在分布式状态下的 Metric 计算。 | |||||
:param backend: 目前支持四种类型的 backend, ``[torch, paddle, jittor, auto]``。其中 ``auto`` 表示根据实际调用 | |||||
Metric.update() 函数时传入的参数决定具体的 ``backend`` ,大部分情况下直接使用 ``auto`` 即可。 | |||||
:param aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 :class:`fastNLP.Evaluator` 中根据 sampler 是否使用分布式 | |||||
进行自动设置。 | |||||
""" | |||||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None): | def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None): | ||||
""" | |||||
:param str backend: 目前支持四种类型的backend, [torch, paddle, jittor, auto]。其中 auto 表示根据实际调用 Metric.update() | |||||
函数时传入的参数决定具体的 backend ,大部分情况下直接使用 auto 即可。 | |||||
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 | |||||
""" | |||||
self.backend = AutoBackend(backend) | self.backend = AutoBackend(backend) | ||||
self._updated = False | self._updated = False | ||||
self.get_metric = self._sync_get_metric(self.get_metric) | self.get_metric = self._sync_get_metric(self.get_metric) | ||||
@@ -39,7 +41,10 @@ class Metric: | |||||
""" | """ | ||||
注册一个 element 对象,注册之后便可以通过在 Metric 中直接通过 self.{name} 进行调用,可以认为该对象即为对应 backend 的 | 注册一个 element 对象,注册之后便可以通过在 Metric 中直接通过 self.{name} 进行调用,可以认为该对象即为对应 backend 的 | ||||
tensor 直接进行加减乘除计算即可。 | tensor 直接进行加减乘除计算即可。 | ||||
注意:如果想使得该 metric 可自动扩展到多卡的情况,请一定申明 aggregate_method 。 | |||||
..warning:: | |||||
如果想使得该 metric 可自动扩展到多卡的情况,请一定申明 aggregate_method 。 | |||||
:param name: 当前 element 的名字,注册后,在 Metric 中可以通过 self.{name} 访问该变量。 | :param name: 当前 element 的名字,注册后,在 Metric 中可以通过 self.{name} 访问该变量。 | ||||
:param value: 初始化的值。在调用 Metric.reset() 方法时也将自动设置为该值 | :param value: 初始化的值。在调用 Metric.reset() 方法时也将自动设置为该值 | ||||
@@ -200,27 +200,27 @@ def _bio_tag_to_spans(tags, ignore_labels=None): | |||||
class SpanFPreRecMetric(Metric): | class SpanFPreRecMetric(Metric): | ||||
r""" | |||||
:param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), | |||||
在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. | |||||
:param pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据 | |||||
:param target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据 | |||||
:param seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。 | |||||
:param encoding_type: 目前支持bio, bmes, bmeso, bioes。默认为None,通过tag_vocab自动判断. | |||||
:param ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'个label | |||||
:param only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个label的f1, pre, rec | |||||
:param f_type: `micro` 或 `macro` . `micro` :通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; `macro` : 分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) | |||||
:param beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . 常用为 `beta=0.5, 1, 2` 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||||
:param backend: 目前支持四种类型的 backend, ``[torch, paddle, jittor, auto]``。其中 ``auto`` 表示根据实际调用 | |||||
Metric.update() 函数时传入的参数决定具体的 ``backend`` ,大部分情况下直接使用 ``auto`` 即可。 | |||||
:param aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 :class:`fastNLP.Evaluator` 中根据 sampler 是否使用分布式 | |||||
进行自动设置。 | |||||
""" | |||||
def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, | def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, | ||||
only_gross: bool = True, f_type='micro', | only_gross: bool = True, f_type='micro', | ||||
beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None) -> None: | beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None) -> None: | ||||
r""" | |||||
:param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), | |||||
在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. | |||||
:param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据 | |||||
:param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据 | |||||
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。 | |||||
:param str encoding_type: 目前支持bio, bmes, bmeso, bioes。默认为None,通过tag_vocab自动判断. | |||||
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'个label | |||||
:param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个label的f1, pre, rec | |||||
:param str f_type: `micro` 或 `macro` . `micro` :通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; `macro` : 分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) | |||||
:param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . 常用为 `beta=0.5, 1, 2` 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||||
:param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() | |||||
函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 | |||||
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 | |||||
""" | |||||
super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | ||||
if f_type not in ('micro', 'macro'): | if f_type not in ('micro', 'macro'): | ||||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | ||||
@@ -10,7 +10,7 @@ def conversion_between_reproducible_and_unrepeated_sampler(sampler): | |||||
将 sampler 替换成其对应的 reproducible 版本或 unrepeated 版本。如果输入是 UnrepeatedSampler 但是没找到对应的 | 将 sampler 替换成其对应的 reproducible 版本或 unrepeated 版本。如果输入是 UnrepeatedSampler 但是没找到对应的 | ||||
ReproducibleSampler, | ReproducibleSampler, | ||||
:param sampler: | |||||
:param sampler: 需要转换的 sampler 。 | |||||
:return: | :return: | ||||
""" | """ | ||||
assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \ | assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \ | ||||
@@ -55,16 +55,16 @@ class ReproducibleBatchSampler: | |||||
class ReproduceBatchSampler(ReproducibleBatchSampler): | class ReproduceBatchSampler(ReproducibleBatchSampler): | ||||
""" | |||||
可以使得 batch_sampler 对象状态恢复的 wrapper 。 | |||||
:param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。ReproduceBatchSampler 将首先遍历一边该对象,然后将迭代 | |||||
出来的序号暂存起来,使用时按照 batch_size 的 batch 大小吐出序号列表。 | |||||
:param batch_size: 每个 batch 的大小是多少。 | |||||
:param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。 | |||||
:param kwargs: fastNLP 内部使用。 | |||||
""" | |||||
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): | def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): | ||||
""" | |||||
可以使得 batch_sampler 对象状态恢复的 wrapper 。 | |||||
:param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。ReproduceBatchSampler 将首先遍历一边该对象,然后将迭代 | |||||
出来的序号暂存起来,使用时按照 batch_size 的 batch 大小吐出序号列表。 | |||||
:param batch_size: 每个 batch 的大小是多少。 | |||||
:param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。 | |||||
:param kwargs: fastNLP 内部使用。 | |||||
""" | |||||
super().__init__() | super().__init__() | ||||
self.batch_sampler = batch_sampler | self.batch_sampler = batch_sampler | ||||
@@ -158,18 +158,18 @@ class ReproduceBatchSampler(ReproducibleBatchSampler): | |||||
class RandomBatchSampler(ReproducibleBatchSampler): | class RandomBatchSampler(ReproducibleBatchSampler): | ||||
""" | |||||
随机分 batch 的 batch_sampler 。 | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||||
:param batch_size: 每个 batch 的大小 | |||||
:param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。 | |||||
:param drop_last: 如果最后一个 batch 的 sample 数量无法凑齐 batch_size 这么多,是否需要丢掉。 | |||||
:param seed: 设置的随机数种子 | |||||
:param kwargs: fastNLP 保留使用 | |||||
""" | |||||
def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True, | def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True, | ||||
drop_last: bool = False, seed: int = 0, **kwargs): | drop_last: bool = False, seed: int = 0, **kwargs): | ||||
""" | |||||
随机分 batch 的 batch_sampler 。 | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||||
:param batch_size: 每个 batch 的大小 | |||||
:param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。 | |||||
:param drop_last: 如果最后一个 batch 的 sample 数量无法凑齐 batch_size 这么多,是否需要丢掉。 | |||||
:param seed: 设置的随机数种子 | |||||
:param kwargs: fastNLP 保留使用 | |||||
""" | |||||
super().__init__() | super().__init__() | ||||
self.dataset = dataset | self.dataset = dataset | ||||
@@ -363,28 +363,28 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
class BucketedBatchSampler(ReproducibleBatchSampler): | class BucketedBatchSampler(ReproducibleBatchSampler): | ||||
""" | |||||
首先按照 ``sample`` 的长度排序,然后按照 batch_size*num_batch_per_bucket 为一个桶的大小,``sample`` 只会在这个桶内进行组 | |||||
合,这样每个 ``batch`` 中的 ``padding`` 数量会比较少 (因为桶内的数据的长度都接近)。 | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||||
:param length: 每条数据的长度。 | |||||
* 为 ``List[int]`` 时 | |||||
应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量; | |||||
* 为 ``str`` 时 | |||||
仅当传入的 ``dataset`` 是 :class:`fastNLP.DataSet` 时,允许传入 `str` ,该 `str` 将被认为是 ``dataset`` 中的 | |||||
``field`` 。若 field 中的元素为 ``int``,则认为该值是 sample 的长度;若不为 ``int`` ,则尝试使用 ``len`` 方法 | |||||
获取该 ``field`` 中每个元素的长度。 | |||||
:param batch_size: 每个 batch 的大小 | |||||
:param num_batch_per_bucket: 多少个 ``batch`` 组成一个桶,数据只会在一个桶内进行 ``shuffle`` 。 | |||||
:param shuffle: 如果为 True,将不进行 ``shuffle``,实际上数据会以从长到短的方式输出。 | |||||
:param drop_last: 如果最后一个 `batch` 的 ``sample`` 数量无法凑齐 ``batch_size`` 这么多,是否需要丢掉。 | |||||
:param seed: 设置的随机数种子 | |||||
:param kwargs: fastNLP 保留使用 | |||||
""" | |||||
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | ||||
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | ||||
""" | |||||
首先按照 ``sample`` 的长度排序,然后按照 batch_size*num_batch_per_bucket 为一个桶的大小,``sample`` 只会在这个桶内进行组 | |||||
合,这样每个 ``batch`` 中的 ``padding`` 数量会比较少 (因为桶内的数据的长度都接近)。 | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||||
:param length: 每条数据的长度。 | |||||
* 为 ``List[int]`` 时 | |||||
应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量; | |||||
* 为 ``str`` 时 | |||||
仅当传入的 ``dataset`` 是 :class:`fastNLP.DataSet` 时,允许传入 `str` ,该 `str` 将被认为是 ``dataset`` 中的 | |||||
``field`` 。若 field 中的元素为 ``int``,则认为该值是 sample 的长度;若不为 ``int`` ,则尝试使用 ``len`` 方法 | |||||
获取该 ``field`` 中每个元素的长度。 | |||||
:param batch_size: 每个 batch 的大小 | |||||
:param num_batch_per_bucket: 多少个 ``batch`` 组成一个桶,数据只会在一个桶内进行 ``shuffle`` 。 | |||||
:param shuffle: 如果为 True,将不进行 ``shuffle``,实际上数据会以从长到短的方式输出。 | |||||
:param drop_last: 如果最后一个 `batch` 的 ``sample`` 数量无法凑齐 ``batch_size`` 这么多,是否需要丢掉。 | |||||
:param seed: 设置的随机数种子 | |||||
:param kwargs: fastNLP 保留使用 | |||||
""" | |||||
super().__init__() | super().__init__() | ||||
if isinstance(dataset, DataSet) and isinstance(length, str): | if isinstance(dataset, DataSet) and isinstance(length, str): | ||||
length = dataset.get_field(length).content | length = dataset.get_field(length).content | ||||
@@ -53,15 +53,15 @@ class ReproducibleSampler: | |||||
class RandomSampler(ReproducibleSampler): | class RandomSampler(ReproducibleSampler): | ||||
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | |||||
""" | |||||
随机顺序的 Sampler 。 | |||||
""" | |||||
随机顺序的 Sampler 。 | |||||
:param dataset: 实现了 __len__ 方法的数据容器 | |||||
:param shuffle: 是否在每次 iterate 的时候打乱顺序。 | |||||
:param seed: 随机数种子。 | |||||
:param kwargs: 用户不需要使用,fastNLP 内部使用 | |||||
""" | |||||
:param dataset: 实现了 __len__ 方法的数据容器 | |||||
:param shuffle: 是否在每次 iterate 的时候打乱顺序。 | |||||
:param seed: 随机数种子。 | |||||
:param kwargs: 用户不需要使用,fastNLP 内部使用 | |||||
""" | |||||
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | |||||
super(RandomSampler, self).__init__() | super(RandomSampler, self).__init__() | ||||
self.dataset = dataset | self.dataset = dataset | ||||
self.shuffle = shuffle | self.shuffle = shuffle | ||||
@@ -213,13 +213,13 @@ class RandomSampler(ReproducibleSampler): | |||||
class SequentialSampler(RandomSampler): | class SequentialSampler(RandomSampler): | ||||
def __init__(self, dataset, **kwargs): | |||||
""" | |||||
按照顺序读取 ``dataset`` 。在多卡情况下,间隔读取,例如,在两卡情况下,卡 0 取 ``[0,2,4,..]``, 卡1取 ``[1,3,5...]`` 。 | |||||
""" | |||||
按照顺序读取 ``dataset`` 。在多卡情况下,间隔读取,例如,在两卡情况下,卡 0 取 ``[0,2,4,..]``, 卡1取 ``[1,3,5...]`` 。 | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||||
:param kwargs: | |||||
""" | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||||
:param kwargs: | |||||
""" | |||||
def __init__(self, dataset, **kwargs): | |||||
super().__init__(dataset=dataset, **kwargs) | super().__init__(dataset=dataset, **kwargs) | ||||
def __iter__(self): | def __iter__(self): | ||||
@@ -283,23 +283,23 @@ class SequentialSampler(RandomSampler): | |||||
class SortedSampler(SequentialSampler): | class SortedSampler(SequentialSampler): | ||||
""" | |||||
将 ``dataset`` 中的数据根据 ``length`` 从长到短进行迭代。在多卡情况下,由于 ``padding`` , 最后一个 ``sample`` 可能是最长 | |||||
的那个 ``sample`` 。 | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||||
:param length: 每条数据的长度。 | |||||
* 为 ``List[int]`` 时 | |||||
应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量; | |||||
* 为 ``str`` 时 | |||||
仅当传入的 ``dataset`` 是 :class:`fastNLP.DataSet` 时,允许传入 `str` ,该 `str` 将被认为是 ``dataset`` 中的 | |||||
``field`` 。若 field 中的元素为 ``int``,则认为该值是 sample 的长度;若不为 ``int`` ,则尝试使用 ``len`` 方法 | |||||
获取该 ``field`` 中每个元素的长度。 | |||||
:param seed: 设置的随机数种子。 | |||||
:param kwargs: fastNLP 保留使用。 | |||||
""" | |||||
def __init__(self, dataset, length:Union[str, List], **kwargs): | def __init__(self, dataset, length:Union[str, List], **kwargs): | ||||
""" | |||||
将 ``dataset`` 中的数据根据 ``length`` 从长到短进行迭代。在多卡情况下,由于 ``padding`` , 最后一个 ``sample`` 可能是最长 | |||||
的那个 ``sample`` 。 | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||||
:param length: 每条数据的长度。 | |||||
* 为 ``List[int]`` 时 | |||||
应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量; | |||||
* 为 ``str`` 时 | |||||
仅当传入的 ``dataset`` 是 :class:`fastNLP.DataSet` 时,允许传入 `str` ,该 `str` 将被认为是 ``dataset`` 中的 | |||||
``field`` 。若 field 中的元素为 ``int``,则认为该值是 sample 的长度;若不为 ``int`` ,则尝试使用 ``len`` 方法 | |||||
获取该 ``field`` 中每个元素的长度。 | |||||
:param seed: 设置的随机数种子。 | |||||
:param kwargs: fastNLP 保留使用。 | |||||
""" | |||||
super().__init__(dataset=dataset, **kwargs) | super().__init__(dataset=dataset, **kwargs) | ||||
if isinstance(dataset, DataSet) and isinstance(length, str): | if isinstance(dataset, DataSet) and isinstance(length, str): | ||||
length = dataset.get_field(length).content | length = dataset.get_field(length).content | ||||
@@ -19,15 +19,15 @@ class UnrepeatedSampler: | |||||
class UnrepeatedRandomSampler(UnrepeatedSampler): | class UnrepeatedRandomSampler(UnrepeatedSampler): | ||||
def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs): | |||||
""" | |||||
考虑在多卡evaluate的场景下,不能重复sample。 | |||||
""" | |||||
考虑在多卡 evaluate 的场景下,不能重复 sample。 | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||||
:param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。 | |||||
:param seed: 设置的随机数种子 | |||||
:param kwargs: fastNLP 保留使用 | |||||
""" | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||||
:param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。 | |||||
:param seed: 设置的随机数种子 | |||||
:param kwargs: fastNLP 保留使用 | |||||
""" | |||||
def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs): | |||||
self.dataset = dataset | self.dataset = dataset | ||||
self.shuffle = shuffle | self.shuffle = shuffle | ||||
self.seed = seed | self.seed = seed | ||||
@@ -96,16 +96,22 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||||
class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | ||||
""" | |||||
将 dataset 中的数据根据 length 从长到短进行迭代,并且保证在多卡场景下数据不重复。本 sampler 可能导致各个机器上的 | |||||
batch 数量不完全一致。 | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||||
:param length: 每条数据的长度。 | |||||
* 为 ``List[int]`` 时 | |||||
应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量; | |||||
* 为 ``str`` 时 | |||||
仅当传入的 ``dataset`` 是 :class:`fastNLP.DataSet` 时,允许传入 `str` ,该 `str` 将被认为是 ``dataset`` 中的 | |||||
``field`` 。若 field 中的元素为 ``int``,则认为该值是 sample 的长度;若不为 ``int`` ,则尝试使用 ``len`` 方法 | |||||
获取该 ``field`` 中每个元素的长度。 | |||||
:param kwargs: fastNLP 保留使用 | |||||
""" | |||||
def __init__(self, dataset, length:Union[str, List], **kwargs): | def __init__(self, dataset, length:Union[str, List], **kwargs): | ||||
""" | |||||
将 dataset 中的数据根据 length 从长到短进行迭代,并且保证在多卡场景下数据不重复。本 sampler 可能导致各个机器上的 | |||||
batch 数量不完全一致。 | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||||
:param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 | |||||
DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。 | |||||
:param kwargs: fastNLP 保留使用 | |||||
""" | |||||
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) | super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) | ||||
if isinstance(dataset, DataSet) and isinstance(length, str): | if isinstance(dataset, DataSet) and isinstance(length, str): | ||||
length = dataset.get_field(length).content | length = dataset.get_field(length).content | ||||
@@ -125,13 +131,13 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | |||||
class UnrepeatedSequentialSampler(UnrepeatedRandomSampler): | class UnrepeatedSequentialSampler(UnrepeatedRandomSampler): | ||||
def __init__(self, dataset, **kwargs): | |||||
""" | |||||
按照顺序读取 dataset。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。 | |||||
""" | |||||
按照顺序读取 dataset。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。 | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||||
:param kwargs: | |||||
""" | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||||
:param kwargs: | |||||
""" | |||||
def __init__(self, dataset, **kwargs): | |||||
super(UnrepeatedSequentialSampler, self).__init__(dataset, shuffle=False, seed=0, **kwargs) | super(UnrepeatedSequentialSampler, self).__init__(dataset, shuffle=False, seed=0, **kwargs) | ||||
def __iter__(self): | def __iter__(self): | ||||
@@ -32,8 +32,8 @@ def is_jittor_dataset(dataset) -> bool: | |||||
def jittor_collate_wraps(func, auto_collator: Callable): | def jittor_collate_wraps(func, auto_collator: Callable): | ||||
""" | """ | ||||
对 ``jittor`` 的 ``collate_fn`` 进行 ``wrap`` 封装,。如果数据集为 ``mapping`` 类型,那么采用 ``auto_collator`` ,否则 | |||||
还是采用 ``jittor`` 的 ``collate_batch``。 | |||||
对 ``jittor`` 的 ``collate_fn`` 进行 ``wrap`` 封装,。如果数据集为 ``mapping`` 类型,那么采用 ``auto_collator`` , | |||||
否则还是采用 ``jittor`` 的 ``collate_batch``。 | |||||
:param func: | :param func: | ||||
:param auto_collator: | :param auto_collator: | ||||
@@ -61,8 +61,8 @@ def _convert_data_device(device: Union[str, int]) -> str: | |||||
def paddle_to(data: "paddle.Tensor", device: Union[str, int]) -> "paddle.Tensor": | def paddle_to(data: "paddle.Tensor", device: Union[str, int]) -> "paddle.Tensor": | ||||
""" | """ | ||||
将 ``data`` 迁移到指定的 ``device`` 上。``paddle.Tensor`` 没有类似 ``torch.Tensor`` 的 ``to`` 函数,该函数 | |||||
只是集成了 :func:`paddle.Tensor.cpu` 和 :func:`paddle.Tensor.cuda` 两个函数。 | |||||
将 ``data`` 迁移到指定的 ``device`` 上。``paddle.Tensor`` 没有类似 ``torch.Tensor`` 的 ``to`` 函数, | |||||
该函数只是集成了 :func:`paddle.Tensor.cpu` 和 :func:`paddle.Tensor.cuda` 两个函数。 | |||||
:param data: 要迁移的张量; | :param data: 要迁移的张量; | ||||
:param device: 目标设备,可以是 ``str`` 或 ``int`` 类型; | :param device: 目标设备,可以是 ``str`` 或 ``int`` 类型; | ||||
@@ -130,8 +130,8 @@ def paddle_move_data_to_device(batch: Any, device: Optional[Union[str, int]]) -> | |||||
将 **paddle** 的数据集合传输到给定设备。只有 :class:`paddle.Tensor` 对象会被传输到设备中,其余保持不变。 | 将 **paddle** 的数据集合传输到给定设备。只有 :class:`paddle.Tensor` 对象会被传输到设备中,其余保持不变。 | ||||
:param batch: 需要进行迁移的数据集合; | :param batch: 需要进行迁移的数据集合; | ||||
:param device: 目标设备。可以是显卡设备的编号,或是``cpu``, ``gpu`` 或 ``gpu:x`` 格式的字符串;当这个参数 | |||||
为 `None`` 时,不会执行任何操作。 | |||||
:param device: 目标设备。可以是显卡设备的编号,或是``cpu``, ``gpu`` 或 ``gpu:x`` 格式的字符串; | |||||
当这个参数为 `None`` 时,不会执行任何操作。 | |||||
:return: 迁移到新设备上的数据集合; | :return: 迁移到新设备上的数据集合; | ||||
""" | """ | ||||
if device is None: | if device is None: | ||||
@@ -1,6 +1,6 @@ | |||||
""" | """ | ||||
该文件用于为 **fastNLP** 提供一个统一的 ``progress bar`` 管理,通过共用一个``Task`` 对象, :class:`~fastNLP.core.Trainer` 中 | |||||
的 ``progress bar`` 和 :class:`~fastNLP.core.Evaluator` 中的 ``progress bar`` 才能不冲突 | |||||
该文件用于为 **fastNLP** 提供一个统一的 ``progress bar`` 管理,通过共用一个``Task`` 对象, :class:`~fastNLP.core.Trainer` | |||||
中的 ``progress bar`` 和 :class:`~fastNLP.core.Evaluator` 中的 ``progress bar`` 才能不冲突 | |||||
""" | """ | ||||
import sys | import sys | ||||
from typing import Any, Union, Optional | from typing import Any, Union, Optional | ||||
@@ -44,11 +44,6 @@ class DummyFRichProgress: | |||||
return True | return True | ||||
class FRichProgress(Progress, metaclass=Singleton): | class FRichProgress(Progress, metaclass=Singleton): | ||||
""" | |||||
fastNLP 使用的 progress bar ,新增了 new_progress 函数,通过此函数即可定制 fastNLP 中所有 progress 的样式。 | |||||
""" | |||||
def new_progess(self, *columns: Union[str, ProgressColumn], | def new_progess(self, *columns: Union[str, ProgressColumn], | ||||
console: Optional[Console] = None, | console: Optional[Console] = None, | ||||
auto_refresh: bool = True, | auto_refresh: bool = True, | ||||
@@ -60,7 +60,7 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||||
``value`` 的参数。 | ``value`` 的参数。 | ||||
1. 该函数用来提供给用户根据字符串匹配从而实现自动调用; | 1. 该函数用来提供给用户根据字符串匹配从而实现自动调用; | ||||
2. 注意 ``mapping`` 默认为 ``None``,如果你希望指定输入和运行函数的参数的对应方式,那么你应当让 ``mapping`` 为一个字典传入进来; | |||||
2. 注意 ``mapping`` 默认为 ``None``,如果您希望指定输入和运行函数的参数的对应方式,那么您应当让 ``mapping`` 为一个字典传入进来; | |||||
如果 ``mapping`` 不为 ``None``,那么我们一定会先使用 ``mapping`` 将输入的字典的 ``keys`` 修改过来,因此请务必亲自检查 ``mapping`` 的正确性; | 如果 ``mapping`` 不为 ``None``,那么我们一定会先使用 ``mapping`` 将输入的字典的 ``keys`` 修改过来,因此请务必亲自检查 ``mapping`` 的正确性; | ||||
3. 如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值; | 3. 如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值; | ||||
4. 如果输入的函数是一个 ``partial`` 函数,情况同第三点,即和默认参数的情况相同; | 4. 如果输入的函数是一个 ``partial`` 函数,情况同第三点,即和默认参数的情况相同; | ||||
@@ -82,8 +82,8 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||||
:param fn: 用来进行实际计算的函数,其参数可以包含有默认值; | :param fn: 用来进行实际计算的函数,其参数可以包含有默认值; | ||||
:param args: 一系列的位置参数,应当为一系列的字典,我们需要从这些输入中提取 ``fn`` 计算所需要的实际参数; | :param args: 一系列的位置参数,应当为一系列的字典,我们需要从这些输入中提取 ``fn`` 计算所需要的实际参数; | ||||
:param signature_fn: 函数,用来替换 ``fn`` 的函数签名,如果该参数不为 ``None``,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取 | |||||
参数值后,再传给 ``fn`` 进行实际的运算; | |||||
:param signature_fn: 函数,用来替换 ``fn`` 的函数签名,如果该参数不为 ``None``,那么我们首先会从该函数中提取函数签名, | |||||
然后通过该函数签名提取参数值后,再传给 ``fn`` 进行实际的运算; | |||||
:param mapping: 一个字典,用来更改其前面的字典的键值; | :param mapping: 一个字典,用来更改其前面的字典的键值; | ||||
:return: 返回 ``fn`` 运行的结果; | :return: 返回 ``fn`` 运行的结果; | ||||
@@ -142,7 +142,7 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||||
if _name not in _has_params: | if _name not in _has_params: | ||||
_has_params[_name] = _value | _has_params[_name] = _value | ||||
if len(_has_params)<len(_need_params): | |||||
if len(_has_params) < len(_need_params): | |||||
miss_params = list(set(_need_params.keys()) - set(_has_params.keys())) | miss_params = list(set(_need_params.keys()) - set(_has_params.keys())) | ||||
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | ||||
_provided_keys = _get_keys(args) | _provided_keys = _get_keys(args) | ||||
@@ -195,8 +195,8 @@ def _get_fun_msg(fn, with_fp=True)->str: | |||||
def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None): | def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None): | ||||
""" | """ | ||||
检查一个函数是否需要 expected_params 参数(检测数量是否匹配)。除掉 self (如果是method),给定默认值的参数等。如果匹配不上,就会 | |||||
进行报错。 | |||||
检查一个函数是否需要 expected_params 参数(检测数量是否匹配)。除掉 self (如果是method),给定默认值的参数等。 | |||||
如果匹配不上,就会进行报错。 | |||||
:param fn: 需要检测的函数,可以是 method 或者 function 。 | :param fn: 需要检测的函数,可以是 method 或者 function 。 | ||||
:param expected_params: 期待应该支持的参数。 | :param expected_params: 期待应该支持的参数。 | ||||
@@ -345,8 +345,8 @@ def apply_to_collection( | |||||
:param dtype: 数据的类型,函数 ``function`` 只会被应用于 ``data`` 中类型为 ``dtype`` 的数据; | :param dtype: 数据的类型,函数 ``function`` 只会被应用于 ``data`` 中类型为 ``dtype`` 的数据; | ||||
:param function: 对数据进行处理的函数; | :param function: 对数据进行处理的函数; | ||||
:param args: ``function`` 所需要的其它参数; | :param args: ``function`` 所需要的其它参数; | ||||
:param wrong_dtype: ``function`` 一定不会生效的数据类型。如果数据既是 ``wrong_dtype`` 类型又是 ``dtype`` 类型 | |||||
那么也不会生效; | |||||
:param wrong_dtype: ``function`` 一定不会生效的数据类型。 | |||||
如果数据既是 ``wrong_dtype`` 类型又是 ``dtype`` 类型那么也不会生效; | |||||
:param include_none: 是否包含执行结果为 ``None`` 的数据,默认为 ``True``; | :param include_none: 是否包含执行结果为 ``None`` 的数据,默认为 ``True``; | ||||
:param kwargs: ``function`` 所需要的其它参数; | :param kwargs: ``function`` 所需要的其它参数; | ||||
:return: 经过 ``function`` 处理后的数据集合; | :return: 经过 ``function`` 处理后的数据集合; | ||||
@@ -587,7 +587,7 @@ def seq_len_to_mask(seq_len, max_len: Optional[int]): | |||||
:param seq_len: 大小为 ``(B,)`` 的长度序列; | :param seq_len: 大小为 ``(B,)`` 的长度序列; | ||||
:param int max_len: 将长度补齐或截断到 ``max_len``。默认情况(为 ``None``)使用的是 ``seq_len`` 中最长的长度; | :param int max_len: 将长度补齐或截断到 ``max_len``。默认情况(为 ``None``)使用的是 ``seq_len`` 中最长的长度; | ||||
但在 :class:`torch.nn.DataParallel` 等分布式的场景下可能不同卡的 ``seq_len`` 会有区别,所以需要传入 | 但在 :class:`torch.nn.DataParallel` 等分布式的场景下可能不同卡的 ``seq_len`` 会有区别,所以需要传入 | ||||
一个 ``max_len`` 使得 ``mask`` 的补齐或截断到该长度。 | |||||
``max_len`` 使得 ``mask`` 的补齐或截断到该长度。 | |||||
:return: 大小为 ``(B, max_len)`` 的 ``mask``, 元素类型为 ``bool`` 或 ``uint8`` | :return: 大小为 ``(B, max_len)`` 的 ``mask``, 元素类型为 ``bool`` 或 ``uint8`` | ||||
""" | """ | ||||
if isinstance(seq_len, np.ndarray): | if isinstance(seq_len, np.ndarray): | ||||
@@ -0,0 +1,15 @@ | |||||
""" | |||||
torch 可使用的几种 Embedding 。 | |||||
""" | |||||
__all__ = [ | |||||
"CNNCharEmbedding", | |||||
"LSTMCharEmbedding", | |||||
"Embedding", | |||||
"StackEmbedding", | |||||
"StaticEmbedding" | |||||
] | |||||
from .char_embedding import * | |||||
from .embedding import * | |||||
from .stack_embedding import * | |||||
from .static_embedding import StaticEmbedding |
@@ -0,0 +1,287 @@ | |||||
r""" | |||||
该文件中主要包含的是character的Embedding,包括基于CNN与LSTM的character Embedding。与其它Embedding一样,这里的Embedding输入也是 | |||||
词的index而不需要使用词语中的char的index来获取表达。 | |||||
""" | |||||
__all__ = [ | |||||
"CNNCharEmbedding", | |||||
"LSTMCharEmbedding" | |||||
] | |||||
from typing import List | |||||
from ...envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from .embedding import TokenEmbedding | |||||
from .static_embedding import StaticEmbedding | |||||
from .utils import _construct_char_vocab_from_vocab | |||||
from .utils import get_embeddings | |||||
from ...core import logger | |||||
from ...core.vocabulary import Vocabulary | |||||
from ...modules.torch.encoder.lstm import LSTM | |||||
class CNNCharEmbedding(TokenEmbedding): | |||||
r""" | |||||
使用 ``CNN`` 生成 ``character embedding``。``CNN`` 的结构为, char_embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool -> fc -> Dropout. | |||||
不同的 ``kernel`` 大小的 ``fitler`` 结果是拼起来然后通过一层``fully connected layer,`` 然后输出``word``的表示。 | |||||
Example:: | |||||
>>> import torch | |||||
>>> from fastNLP import Vocabulary | |||||
>>> from fastNLP.embeddings.torch import CNNCharEmbedding | |||||
>>> vocab = Vocabulary().add_word_lst("The whether is good .".split()) | |||||
>>> embed = CNNCharEmbedding(vocab, embed_size=50) | |||||
>>> words = torch.LongTensor([[vocab.to_index(word) for word in "The whether is good .".split()]]) | |||||
>>> outputs = embed(words) | |||||
>>> outputs.size() | |||||
# torch.Size([1, 5,50]) | |||||
""" | |||||
def __init__(self, vocab: Vocabulary, embed_size: int = 50, char_emb_size: int = 50, word_dropout: float = 0, | |||||
dropout: float = 0, filter_nums: List[int] = (40, 30, 20), kernel_sizes: List[int] = (5, 3, 1), | |||||
pool_method: str = 'max', activation='relu', min_char_freq: int = 2, pre_train_char_embed: str = None, | |||||
requires_grad:bool=True, include_word_start_end:bool=True): | |||||
r""" | |||||
:param vocab: 词表 | |||||
:param embed_size: 该CNNCharEmbedding的输出维度大小,默认值为50. | |||||
:param char_emb_size: character的embed的维度。character是从vocab中生成的。默认值为50. | |||||
:param word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||||
:param dropout: 以多大的概率drop分布式表示与char embedding的输出。 | |||||
:param filter_nums: filter的数量. 长度需要和kernels一致。默认值为[40, 30, 20]. | |||||
:param kernel_sizes: kernel的大小. 默认值为[5, 3, 1]. | |||||
:param pool_method: character的表示在合成一个表示时所使用的pool方法,支持'avg', 'max'. | |||||
:param activation: CNN之后使用的激活方法,支持'relu', 'sigmoid', 'tanh' 或者自定义函数. | |||||
:param min_char_freq: character的最少出现次数。默认值为2. | |||||
:param pre_train_char_embed: 可以有两种方式调用预训练好的character embedding:第一种是传入embedding文件夹 | |||||
(文件夹下应该只有一个以.txt作为后缀的文件)或文件路径;第二种是传入embedding的名称,第二种情况将自动查看缓存中是否存在该模型, | |||||
没有的话将自动下载。如果输入为None则使用embedding_dim的维度随机初始化一个embedding. | |||||
:param requires_grad: 是否更新权重 | |||||
:param include_word_start_end: 是否在每个word开始的character前和结束的character增加特殊标示符号; | |||||
""" | |||||
super(CNNCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||||
for kernel in kernel_sizes: | |||||
assert kernel % 2 == 1, "Only odd kernel is allowed." | |||||
assert pool_method in ('max', 'avg') | |||||
self.pool_method = pool_method | |||||
# activation function | |||||
if isinstance(activation, str): | |||||
if activation.lower() == 'relu': | |||||
self.activation = F.relu | |||||
elif activation.lower() == 'sigmoid': | |||||
self.activation = F.sigmoid | |||||
elif activation.lower() == 'tanh': | |||||
self.activation = F.tanh | |||||
elif activation is None: | |||||
self.activation = lambda x: x | |||||
elif callable(activation): | |||||
self.activation = activation | |||||
else: | |||||
raise Exception( | |||||
"Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]") | |||||
logger.info("Start constructing character vocabulary.") | |||||
# 建立char的词表 | |||||
self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq, | |||||
include_word_start_end=include_word_start_end) | |||||
self.char_pad_index = self.char_vocab.padding_idx | |||||
logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.") | |||||
# 对vocab进行index | |||||
max_word_len = max(map(lambda x: len(x[0]), vocab)) | |||||
if include_word_start_end: | |||||
max_word_len += 2 | |||||
self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), max_word_len), | |||||
fill_value=self.char_pad_index, dtype=torch.long)) | |||||
self.register_buffer('word_lengths', torch.zeros(len(vocab)).long()) | |||||
for word, index in vocab: | |||||
# if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了。修改为不区分pad, 这样所有的<pad>也是同一个embed | |||||
if include_word_start_end: | |||||
word = ['<bow>'] + list(word) + ['<eow>'] | |||||
self.words_to_chars_embedding[index, :len(word)] = \ | |||||
torch.LongTensor([self.char_vocab.to_index(c) for c in word]) | |||||
self.word_lengths[index] = len(word) | |||||
# self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size) | |||||
if pre_train_char_embed: | |||||
self.char_embedding = StaticEmbedding(self.char_vocab, model_dir_or_name=pre_train_char_embed) | |||||
else: | |||||
self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size)) | |||||
self.convs = nn.ModuleList([nn.Conv1d( | |||||
self.char_embedding.embedding_dim, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, | |||||
padding=kernel_sizes[i] // 2) | |||||
for i in range(len(kernel_sizes))]) | |||||
self._embed_size = embed_size | |||||
self.fc = nn.Linear(sum(filter_nums), embed_size) | |||||
self.requires_grad = requires_grad | |||||
def forward(self, words): | |||||
r""" | |||||
输入words的index后,生成对应的words的表示。 | |||||
:param words: [batch_size, max_len] | |||||
:return: [batch_size, max_len, embed_size] | |||||
""" | |||||
words = self.drop_word(words) | |||||
batch_size, max_len = words.size() | |||||
chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len | |||||
word_lengths = self.word_lengths[words] # batch_size x max_len | |||||
max_word_len = word_lengths.max() | |||||
chars = chars[:, :, :max_word_len] | |||||
# 为1的地方为mask | |||||
chars_masks = chars.eq(self.char_pad_index) # batch_size x max_len x max_word_len 如果为0, 说明是padding的位置了 | |||||
chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size | |||||
chars = self.dropout(chars) | |||||
reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1) | |||||
reshaped_chars = reshaped_chars.transpose(1, 2) # B' x E x M | |||||
conv_chars = [conv(reshaped_chars).transpose(1, 2).reshape(batch_size, max_len, max_word_len, -1) | |||||
for conv in self.convs] | |||||
conv_chars = torch.cat(conv_chars, dim=-1).contiguous() # B x max_len x max_word_len x sum(filters) | |||||
conv_chars = self.activation(conv_chars) | |||||
if self.pool_method == 'max': | |||||
conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf')) | |||||
chars, _ = torch.max(conv_chars, dim=-2) # batch_size x max_len x sum(filters) | |||||
else: | |||||
conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), 0) | |||||
chars = torch.sum(conv_chars, dim=-2) / chars_masks.eq(False).sum(dim=-1, keepdim=True).float() | |||||
chars = self.fc(chars) | |||||
return self.dropout(chars) | |||||
class LSTMCharEmbedding(TokenEmbedding): | |||||
r""" | |||||
使用 ``LSTM`` 的方式对 ``character`` 进行 ``encode``. embed(x) -> Dropout(x) -> LSTM(x) -> activation(x) -> pool -> Dropout | |||||
Example:: | |||||
>>> import torch | |||||
>>> from fastNLP import Vocabulary | |||||
>>> from fastNLP.embeddings.torch import LSTMCharEmbedding | |||||
>>> vocab = Vocabulary().add_word_lst("The whether is good .".split()) | |||||
>>> embed = LSTMCharEmbedding(vocab, embed_size=50) | |||||
>>> words = torch.LongTensor([[vocab.to_index(word) for word in "The whether is good .".split()]]) | |||||
>>> outputs = embed(words) | |||||
>>> outputs.size() | |||||
>>> # torch.Size([1, 5,50]) | |||||
""" | |||||
def __init__(self, vocab: Vocabulary, embed_size: int = 50, char_emb_size: int = 50, word_dropout: float = 0, | |||||
dropout: float = 0, hidden_size=50, pool_method: str = 'max', activation='relu', | |||||
min_char_freq: int = 2, bidirectional=True, pre_train_char_embed: str = None, | |||||
requires_grad:bool=True, include_word_start_end:bool=True): | |||||
r""" | |||||
:param vocab: 词表 | |||||
:param embed_size: LSTMCharEmbedding的输出维度。默认值为50. | |||||
:param char_emb_size: character的embedding的维度。默认值为50. | |||||
:param word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||||
:param dropout: 以多大概率drop character embedding的输出以及最终的word的输出。 | |||||
:param hidden_size: LSTM的中间hidden的大小,如果为bidirectional的,hidden会除二,默认为50. | |||||
:param pool_method: 支持'max', 'avg'。 | |||||
:param activation: 激活函数,支持'relu', 'sigmoid', 'tanh', 或者自定义函数. | |||||
:param min_char_freq: character的最小出现次数。默认值为2. | |||||
:param bidirectional: 是否使用双向的LSTM进行encode。默认值为True。 | |||||
:param pre_train_char_embed: 可以有两种方式调用预训练好的character embedding:第一种是传入embedding文件夹 | |||||
(文件夹下应该只有一个以.txt作为后缀的文件)或文件路径;第二种是传入embedding的名称,第二种情况将自动查看缓存中是否存在该模型, | |||||
没有的话将自动下载。如果输入为None则使用embedding_dim的维度随机初始化一个embedding. | |||||
:param requires_grad: 是否更新权重 | |||||
:param include_word_start_end: 是否在每个word开始的character前和结束的character增加特殊标示符号; | |||||
""" | |||||
super(LSTMCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||||
assert hidden_size % 2 == 0, "Only even kernel is allowed." | |||||
assert pool_method in ('max', 'avg') | |||||
self.pool_method = pool_method | |||||
# activation function | |||||
if isinstance(activation, str): | |||||
if activation.lower() == 'relu': | |||||
self.activation = F.relu | |||||
elif activation.lower() == 'sigmoid': | |||||
self.activation = F.sigmoid | |||||
elif activation.lower() == 'tanh': | |||||
self.activation = F.tanh | |||||
elif activation is None: | |||||
self.activation = lambda x: x | |||||
elif callable(activation): | |||||
self.activation = activation | |||||
else: | |||||
raise Exception( | |||||
"Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]") | |||||
logger.info("Start constructing character vocabulary.") | |||||
# 建立char的词表 | |||||
self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq, | |||||
include_word_start_end=include_word_start_end) | |||||
self.char_pad_index = self.char_vocab.padding_idx | |||||
logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.") | |||||
# 对vocab进行index | |||||
max_word_len = max(map(lambda x: len(x[0]), vocab)) | |||||
if include_word_start_end: | |||||
max_word_len += 2 | |||||
self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), max_word_len), | |||||
fill_value=self.char_pad_index, dtype=torch.long)) | |||||
self.register_buffer('word_lengths', torch.zeros(len(vocab)).long()) | |||||
for word, index in vocab: | |||||
# if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了. 修改为不区分pad与否 | |||||
if include_word_start_end: | |||||
word = ['<bow>'] + list(word) + ['<eow>'] | |||||
self.words_to_chars_embedding[index, :len(word)] = \ | |||||
torch.LongTensor([self.char_vocab.to_index(c) for c in word]) | |||||
self.word_lengths[index] = len(word) | |||||
if pre_train_char_embed: | |||||
self.char_embedding = StaticEmbedding(self.char_vocab, pre_train_char_embed) | |||||
else: | |||||
self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size)) | |||||
self.fc = nn.Linear(hidden_size, embed_size) | |||||
hidden_size = hidden_size // 2 if bidirectional else hidden_size | |||||
self.lstm = LSTM(self.char_embedding.embedding_dim, hidden_size, bidirectional=bidirectional, batch_first=True) | |||||
self._embed_size = embed_size | |||||
self.bidirectional = bidirectional | |||||
self.requires_grad = requires_grad | |||||
def forward(self, words): | |||||
r""" | |||||
输入words的index后,生成对应的words的表示。 | |||||
:param words: [batch_size, max_len] | |||||
:return: [batch_size, max_len, embed_size] | |||||
""" | |||||
words = self.drop_word(words) | |||||
batch_size, max_len = words.size() | |||||
chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len | |||||
word_lengths = self.word_lengths[words] # batch_size x max_len | |||||
max_word_len = word_lengths.max() | |||||
chars = chars[:, :, :max_word_len] | |||||
# 为mask的地方为1 | |||||
chars_masks = chars.eq(self.char_pad_index) # batch_size x max_len x max_word_len 如果为0, 说明是padding的位置了 | |||||
chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size | |||||
chars = self.dropout(chars) | |||||
reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1) | |||||
char_seq_len = chars_masks.eq(False).sum(dim=-1).reshape(batch_size * max_len) | |||||
lstm_chars = self.lstm(reshaped_chars, char_seq_len)[0].reshape(batch_size, max_len, max_word_len, -1) | |||||
# B x M x M x H | |||||
lstm_chars = self.activation(lstm_chars) | |||||
if self.pool_method == 'max': | |||||
lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf')) | |||||
chars, _ = torch.max(lstm_chars, dim=-2) # batch_size x max_len x H | |||||
else: | |||||
lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), 0) | |||||
chars = torch.sum(lstm_chars, dim=-2) / chars_masks.eq(False).sum(dim=-1, keepdim=True).float() | |||||
chars = self.fc(chars) | |||||
return self.dropout(chars) |
@@ -0,0 +1,220 @@ | |||||
r""" | |||||
该模块中的Embedding主要用于随机初始化的embedding(更推荐使用 :class:`fastNLP.embeddings.StaticEmbedding` ),或按照预训练权重初始化Embedding。 | |||||
""" | |||||
__all__ = [ | |||||
"Embedding", | |||||
] | |||||
from abc import abstractmethod | |||||
from typing import Union, Tuple | |||||
from ...envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from torch.nn import Module | |||||
from torch import nn | |||||
else: | |||||
from ...core.utils.dummy_class import DummyClass as Module | |||||
import numpy as np | |||||
from .utils import get_embeddings | |||||
class Embedding(Module): | |||||
r""" | |||||
词向量嵌入,支持输入多种方式初始化. 可以通过 ``self.num_embeddings`` 获取词表大小; ``self.embedding_dim`` 获取 ``embedding`` 的维度. | |||||
Example:: | |||||
>>> import numpy as np | |||||
>>> from fastNLP.embeddings.torch import Embedding | |||||
>>> init_embed = (2000, 100) | |||||
>>> embed = Embedding(init_embed) # 随机初始化一个具有2000个词,每个词表示为100维的词向量 | |||||
>>> init_embed = np.zeros((2000, 100)) | |||||
>>> embed = Embedding(init_embed) # 使用numpy.ndarray的值作为初始化值初始化一个Embedding | |||||
""" | |||||
def __init__(self, init_embed:Union[Tuple[int,int],'torch.FloatTensor','nn.Embedding',np.ndarray], | |||||
word_dropout:float=0, dropout:float=0.0, unk_index:int=None): | |||||
r""" | |||||
:param init_embed: 支持传入Embedding的大小(传入tuple(int, int), | |||||
第一个int为vocab_zie, 第二个int为embed_dim); 或传入Tensor, Embedding, numpy.ndarray等则直接使用该值初始化Embedding; | |||||
:param word_dropout: 按照一定概率随机将word设置为unk_index,这样可以使得unk这个token得到足够的训练, 且会对网络有 | |||||
一定的regularize的作用。设置该值时,必须同时设置unk_index | |||||
:param dropout: 对Embedding的输出的dropout。 | |||||
:param unk_index: drop word时替换为的index。fastNLP的Vocabulary的unk_index默认为1。 | |||||
""" | |||||
super(Embedding, self).__init__() | |||||
self.embed = get_embeddings(init_embed) | |||||
self.dropout = nn.Dropout(dropout) | |||||
if not isinstance(self.embed, TokenEmbedding): | |||||
if hasattr(self.embed, 'embed_size'): | |||||
self._embed_size = self.embed.embed_size | |||||
elif hasattr(self.embed, 'embedding_dim'): | |||||
self._embed_size = self.embed.embedding_dim | |||||
else: | |||||
self._embed_size = self.embed.weight.size(1) | |||||
if word_dropout > 0 and not isinstance(unk_index, int): | |||||
raise ValueError("When drop word is set, you need to pass in the unk_index.") | |||||
else: | |||||
self._embed_size = self.embed.embed_size | |||||
unk_index = self.embed.get_word_vocab().unknown_idx | |||||
self.unk_index = unk_index | |||||
self.word_dropout = word_dropout | |||||
def forward(self, words): | |||||
r""" | |||||
:param torch.LongTensor words: [batch, seq_len] | |||||
:return: torch.Tensor : [batch, seq_len, embed_dim] | |||||
""" | |||||
if self.word_dropout > 0 and self.training: | |||||
mask = torch.ones_like(words).float() * self.word_dropout | |||||
mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | |||||
words = words.masked_fill(mask, self.unk_index) | |||||
words = self.embed(words) | |||||
return self.dropout(words) | |||||
@property | |||||
def num_embedding(self) -> int: | |||||
if isinstance(self.embed, nn.Embedding): | |||||
return self.embed.weight.size(0) | |||||
else: | |||||
return self.embed.num_embeddings | |||||
def __len__(self): | |||||
return len(self.embed) | |||||
@property | |||||
def embed_size(self) -> int: | |||||
return self._embed_size | |||||
@property | |||||
def embedding_dim(self) -> int: | |||||
return self._embed_size | |||||
@property | |||||
def requires_grad(self): | |||||
r""" | |||||
Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许 | |||||
:return: | |||||
""" | |||||
if not isinstance(self.embed, TokenEmbedding): | |||||
return self.embed.weight.requires_grad | |||||
else: | |||||
return self.embed.requires_grad | |||||
@requires_grad.setter | |||||
def requires_grad(self, value): | |||||
if not isinstance(self.embed, TokenEmbedding): | |||||
self.embed.weight.requires_grad = value | |||||
else: | |||||
self.embed.requires_grad = value | |||||
@property | |||||
def size(self): | |||||
if isinstance(self.embed, TokenEmbedding): | |||||
return self.embed.size | |||||
else: | |||||
return self.embed.weight.size() | |||||
class TokenEmbedding(Module): | |||||
r""" | |||||
fastNLP中各种Embedding的基类 | |||||
""" | |||||
def __init__(self, vocab, word_dropout=0.0, dropout=0.0): | |||||
super(TokenEmbedding, self).__init__() | |||||
if vocab.rebuild: | |||||
vocab.build_vocab() | |||||
assert vocab.padding is not None, "Vocabulary must have a padding entry." | |||||
self._word_vocab = vocab | |||||
self._word_pad_index = vocab.padding_idx | |||||
if word_dropout > 0: | |||||
assert vocab.unknown is not None, "Vocabulary must have unknown entry when you want to drop a word." | |||||
self.word_dropout = word_dropout | |||||
self._word_unk_index = vocab.unknown_idx | |||||
self.dropout_layer = nn.Dropout(dropout) | |||||
def drop_word(self, words): | |||||
r""" | |||||
按照设定随机将words设置为unknown_index。 | |||||
:param torch.LongTensor words: batch_size x max_len | |||||
:return: | |||||
""" | |||||
if self.word_dropout > 0 and self.training: | |||||
mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device) | |||||
mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | |||||
pad_mask = words.ne(self._word_pad_index) | |||||
mask = mask.__and__(pad_mask) | |||||
words = words.masked_fill(mask, self._word_unk_index) | |||||
return words | |||||
def dropout(self, words): | |||||
r""" | |||||
对embedding后的word表示进行drop。 | |||||
:param torch.FloatTensor words: batch_size x max_len x embed_size | |||||
:return: | |||||
""" | |||||
return self.dropout_layer(words) | |||||
@property | |||||
def requires_grad(self): | |||||
r""" | |||||
Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许 | |||||
:return: | |||||
""" | |||||
requires_grads = set([param.requires_grad for param in self.parameters()]) | |||||
if len(requires_grads) == 1: | |||||
return requires_grads.pop() | |||||
else: | |||||
return None | |||||
@requires_grad.setter | |||||
def requires_grad(self, value): | |||||
for param in self.parameters(): | |||||
param.requires_grad = value | |||||
def __len__(self): | |||||
return len(self._word_vocab) | |||||
@property | |||||
def embed_size(self) -> int: | |||||
return self._embed_size | |||||
@property | |||||
def embedding_dim(self) -> int: | |||||
return self._embed_size | |||||
@property | |||||
def num_embeddings(self) -> int: | |||||
r""" | |||||
这个值可能会大于实际的embedding矩阵的大小。 | |||||
:return: | |||||
""" | |||||
return len(self._word_vocab) | |||||
def get_word_vocab(self): | |||||
r""" | |||||
返回embedding的词典。 | |||||
:return: Vocabulary | |||||
""" | |||||
return self._word_vocab | |||||
@property | |||||
def size(self): | |||||
return torch.Size(self.num_embeddings, self._embed_size) | |||||
@abstractmethod | |||||
def forward(self, words): | |||||
raise NotImplementedError |
@@ -0,0 +1,101 @@ | |||||
r""" | |||||
.. todo:: | |||||
doc | |||||
""" | |||||
__all__ = [ | |||||
"StackEmbedding", | |||||
] | |||||
from typing import List | |||||
from ...envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from torch import nn | |||||
from .embedding import TokenEmbedding | |||||
from .utils import _check_vocab_has_same_index | |||||
class StackEmbedding(TokenEmbedding): | |||||
r""" | |||||
支持将多个embedding集合成一个embedding。 | |||||
Example:: | |||||
>>> from fastNLP import Vocabulary | |||||
>>> from fastNLP.embeddings.torch import StaticEmbedding, StackEmbedding | |||||
>>> vocab = Vocabulary().add_word_lst("The whether is good .".split()) | |||||
>>> embed_1 = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d', requires_grad=True) | |||||
>>> embed_2 = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True) | |||||
>>> embed = StackEmbedding([embed_1, embed_2]) | |||||
""" | |||||
def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0): | |||||
r""" | |||||
:param embeds: 一个由若干个TokenEmbedding组成的list,要求每一个TokenEmbedding的词表都保持一致 | |||||
:param word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。不同embedidng会在相同的位置 | |||||
被设置为unknown。如果这里设置了dropout,则组成的embedding就不要再设置dropout了。 | |||||
:param dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。 | |||||
""" | |||||
vocabs = [] | |||||
for embed in embeds: | |||||
if hasattr(embed, 'get_word_vocab'): | |||||
vocabs.append(embed.get_word_vocab()) | |||||
_vocab = vocabs[0] | |||||
for vocab in vocabs[1:]: | |||||
if _vocab!=vocab: | |||||
_check_vocab_has_same_index(_vocab, vocab) | |||||
super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout) | |||||
assert isinstance(embeds, list) | |||||
for embed in embeds: | |||||
assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported." | |||||
self.embeds = nn.ModuleList(embeds) | |||||
self._embed_size = sum([embed.embed_size for embed in self.embeds]) | |||||
def append(self, embed: TokenEmbedding): | |||||
r""" | |||||
添加一个embedding到结尾。 | |||||
:param embed: | |||||
:return: | |||||
""" | |||||
assert isinstance(embed, TokenEmbedding) | |||||
_check_vocab_has_same_index(self.get_word_vocab(), embed.get_word_vocab()) | |||||
self._embed_size += embed.embed_size | |||||
self.embeds.append(embed) | |||||
return self | |||||
def pop(self): | |||||
r""" | |||||
弹出最后一个embed | |||||
:return: | |||||
""" | |||||
embed = self.embeds.pop() | |||||
self._embed_size -= embed.embed_size | |||||
return embed | |||||
@property | |||||
def embed_size(self): | |||||
r""" | |||||
该Embedding输出的vector的最后一维的维度。 | |||||
:return: | |||||
""" | |||||
return self._embed_size | |||||
def forward(self, words): | |||||
r""" | |||||
得到多个embedding的结果,并把结果按照顺序concat起来。 | |||||
:param words: batch_size x max_len | |||||
:return: 返回的shape和当前这个stack embedding中embedding的组成有关 | |||||
""" | |||||
outputs = [] | |||||
words = self.drop_word(words) | |||||
for embed in self.embeds: | |||||
outputs.append(embed(words)) | |||||
outputs = self.dropout(torch.cat(outputs, dim=-1)) | |||||
return outputs |
@@ -0,0 +1,407 @@ | |||||
r""" | |||||
.. todo:: | |||||
doc | |||||
""" | |||||
__all__ = [ | |||||
"StaticEmbedding" | |||||
] | |||||
import os | |||||
import warnings | |||||
from collections import defaultdict | |||||
from copy import deepcopy | |||||
import json | |||||
from typing import Union | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
from .embedding import TokenEmbedding | |||||
from ...core import logger | |||||
from ...core.vocabulary import Vocabulary | |||||
from ...io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path | |||||
from ...io.file_utils import _get_file_name_base_on_postfix | |||||
VOCAB_FILENAME = 'vocab.txt' | |||||
STATIC_HYPER_FILENAME = 'static_hyper.json' | |||||
STATIC_EMBED_FILENAME = 'static.txt' | |||||
class StaticEmbedding(TokenEmbedding): | |||||
r""" | |||||
StaticEmbedding组件. 给定预训练embedding的名称或路径,根据vocab从embedding中抽取相应的数据(只会将出现在vocab中的词抽取出来, | |||||
如果没有找到,则会随机初始化一个值(但如果该word是被标记为no_create_entry的话,则不会单独创建一个值,而是会被指向unk的index))。 | |||||
当前支持自动下载的预训练vector有: | |||||
.. code:: | |||||
en: 实际为en-glove-840b-300d(常用) | |||||
en-glove-6b-50d: glove官方的50d向量 | |||||
en-glove-6b-100d: glove官方的100d向量 | |||||
en-glove-6b-200d: glove官方的200d向量 | |||||
en-glove-6b-300d: glove官方的300d向量 | |||||
en-glove-42b-300d: glove官方使用42B数据训练版本 | |||||
en-glove-840b-300d: | |||||
en-glove-twitter-27b-25d: | |||||
en-glove-twitter-27b-50d: | |||||
en-glove-twitter-27b-100d: | |||||
en-glove-twitter-27b-200d: | |||||
en-word2vec-300d: word2vec官方发布的300d向量 | |||||
en-fasttext-crawl: fasttext官方发布的300d英文预训练 | |||||
cn-char-fastnlp-100d: fastNLP训练的100d的character embedding | |||||
cn-bi-fastnlp-100d: fastNLP训练的100d的bigram embedding | |||||
cn-tri-fastnlp-100d: fastNLP训练的100d的trigram embedding | |||||
cn-fasttext: fasttext官方发布的300d中文预训练embedding | |||||
Example:: | |||||
>>> from fastNLP import Vocabulary | |||||
>>> from fastNLP.embeddings.torch import StaticEmbedding | |||||
>>> vocab = Vocabulary().add_word_lst("The whether is good .".split()) | |||||
>>> embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-50d') | |||||
>>> vocab = Vocabulary().add_word_lst(["The", 'the', "THE"]) | |||||
>>> embed = StaticEmbedding(vocab, model_dir_or_name="en-glove-50d", lower=True) | |||||
>>> # "the", "The", "THE"它们共用一个vector,且将使用"the"在预训练词表中寻找它们的初始化表示。 | |||||
>>> vocab = Vocabulary().add_word_lst(["The", "the", "THE"]) | |||||
>>> embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True) | |||||
>>> words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE"]]]) | |||||
>>> embed(words) | |||||
>>> tensor([[[ 0.5773, 0.7251, -0.3104, 0.0777, 0.4849], | |||||
[ 0.5773, 0.7251, -0.3104, 0.0777, 0.4849], | |||||
[ 0.5773, 0.7251, -0.3104, 0.0777, 0.4849]]], | |||||
grad_fn=<EmbeddingBackward>) # 每种word的输出是一致的。 | |||||
""" | |||||
def __init__(self, vocab: Vocabulary, model_dir_or_name: Union[str, None] = 'en', embedding_dim=-1, requires_grad: bool = True, | |||||
init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs): | |||||
r""" | |||||
:param Vocabulary vocab: 词表. StaticEmbedding只会加载包含在词表中的词的词向量,在预训练向量中没找到的使用随机初始化 | |||||
:param model_dir_or_name: 可以有两种方式调用预训练好的static embedding:第一种是传入embedding文件夹(文件夹下应该只有一个 | |||||
以.txt作为后缀的文件)或文件路径;第二种是传入embedding的名称,第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载。 | |||||
如果输入为None则使用embedding_dim的维度随机初始化一个embedding。 | |||||
:param embedding_dim: 随机初始化的embedding的维度,当该值为大于0的值时,将忽略model_dir_or_name。 | |||||
:param requires_grad: 是否需要gradient. 默认为True | |||||
:param callable init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法, 传入的方法应该接受一个tensor,并 | |||||
inplace地修改其值。 | |||||
:param lower: 是否将vocab中的词语小写后再和预训练的词表进行匹配。如果你的词表中包含大写的词语,或者就是需要单独 | |||||
为大写的词语开辟一个vector表示,则将lower设置为False。 | |||||
:param dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。 | |||||
:param word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||||
:param normalize: 是否对vector进行normalize,使得每个vector的norm为1。 | |||||
:param min_freq: Vocabulary词频数小于这个数量的word将被指向unk。 | |||||
:param kwargs: | |||||
* only_train_min_freq * (*bool*) -- 仅对 train 中的词语使用 ``min_freq`` 筛选; | |||||
* only_norm_found_vector * (*bool*) -- 默认为False, 是否仅对在预训练中找到的词语使用normalize; | |||||
* only_use_pretrain_word * (*bool*) -- 默认为False, 仅使用出现在pretrain词表中的词,如果该词没有在预训练的词表中出现 | |||||
则为unk。如果embedding不需要更新建议设置为True。 | |||||
""" | |||||
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||||
if embedding_dim > 0: | |||||
if model_dir_or_name: | |||||
logger.info(f"StaticEmbedding will ignore `model_dir_or_name`, and randomly initialize embedding with" | |||||
f" dimension {embedding_dim}. If you want to use pre-trained embedding, " | |||||
f"set `embedding_dim` to 0.") | |||||
model_dir_or_name = None | |||||
# 得到cache_path | |||||
if model_dir_or_name is None: | |||||
assert embedding_dim >= 1, "The dimension of embedding should be larger than 1." | |||||
embedding_dim = int(embedding_dim) | |||||
model_path = None | |||||
elif model_dir_or_name.lower() in PRETRAIN_STATIC_FILES: | |||||
model_url = _get_embedding_url('static', model_dir_or_name.lower()) | |||||
model_path = cached_path(model_url, name='embedding') | |||||
# 检查是否存在 | |||||
elif os.path.isfile(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||||
model_path = os.path.abspath(os.path.expanduser(model_dir_or_name)) | |||||
elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||||
model_path = _get_file_name_base_on_postfix(os.path.abspath(os.path.expanduser(model_dir_or_name)), '.txt') | |||||
else: | |||||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||||
kwargs['min_freq'] = min_freq | |||||
kwargs['lower'] = lower | |||||
# 根据min_freq缩小vocab | |||||
truncate_vocab = (vocab.min_freq is None and min_freq > 1) or (vocab.min_freq and vocab.min_freq < min_freq) | |||||
if truncate_vocab: | |||||
truncated_vocab = deepcopy(vocab) | |||||
truncated_vocab.min_freq = min_freq | |||||
truncated_vocab.word2idx = None | |||||
if lower: # 如果有lower,将大小写的的freq需要同时考虑到 | |||||
lowered_word_count = defaultdict(int) | |||||
for word, count in truncated_vocab.word_count.items(): | |||||
lowered_word_count[word.lower()] += count | |||||
for word in truncated_vocab.word_count.keys(): | |||||
word_count = truncated_vocab.word_count[word] | |||||
if lowered_word_count[word.lower()] >= min_freq and word_count < min_freq: | |||||
truncated_vocab.add_word_lst([word] * (min_freq - word_count), | |||||
no_create_entry=truncated_vocab._is_word_no_create_entry(word)) | |||||
# 只限制在train里面的词语使用min_freq筛选 | |||||
if kwargs.get('only_train_min_freq', False) and model_dir_or_name is not None: | |||||
for word in truncated_vocab.word_count.keys(): | |||||
if truncated_vocab._is_word_no_create_entry(word) and truncated_vocab.word_count[word] < min_freq: | |||||
truncated_vocab.add_word_lst([word] * (min_freq - truncated_vocab.word_count[word]), | |||||
no_create_entry=True) | |||||
truncated_vocab.build_vocab() | |||||
truncated_words_to_words = torch.arange(len(vocab)).long() | |||||
for word, index in vocab: | |||||
truncated_words_to_words[index] = truncated_vocab.to_index(word) | |||||
logger.info(f"{len(vocab) - len(truncated_vocab)} words have frequency less than {min_freq}.") | |||||
vocab = truncated_vocab | |||||
self.only_use_pretrain_word = kwargs.get('only_use_pretrain_word', False) | |||||
self.only_norm_found_vector = kwargs.get('only_norm_found_vector', False) | |||||
# 读取embedding | |||||
if lower: | |||||
lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown) | |||||
for word, index in vocab: | |||||
if vocab._is_word_no_create_entry(word): | |||||
lowered_vocab.add_word(word.lower(), no_create_entry=True) | |||||
else: | |||||
lowered_vocab.add_word(word.lower()) # 先加入需要创建entry的 | |||||
logger.info(f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} " | |||||
f"unique lowered words.") | |||||
if model_path: | |||||
embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method) | |||||
else: | |||||
embedding = self._randomly_init_embed(len(lowered_vocab), embedding_dim, init_method) | |||||
self.register_buffer('words_to_words', torch.arange(len(vocab)).long()) | |||||
if lowered_vocab.unknown: | |||||
unknown_idx = lowered_vocab.unknown_idx | |||||
else: | |||||
unknown_idx = embedding.size(0) - 1 # 否则是最后一个为unknow | |||||
self.register_buffer('words_to_words', torch.arange(len(vocab)).long()) | |||||
words_to_words = torch.full((len(vocab),), fill_value=unknown_idx, dtype=torch.long).long() | |||||
for word, index in vocab: | |||||
if word not in lowered_vocab: | |||||
word = word.lower() | |||||
if word not in lowered_vocab and lowered_vocab._is_word_no_create_entry(word): | |||||
continue # 如果不需要创建entry,已经默认unknown了 | |||||
words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)] | |||||
self.register_buffer('words_to_words', words_to_words) | |||||
self._word_unk_index = lowered_vocab.unknown_idx # 替换一下unknown的index | |||||
else: | |||||
if model_path: | |||||
embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method) | |||||
else: | |||||
embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method) | |||||
self.register_buffer('words_to_words', torch.arange(len(vocab)).long()) | |||||
if not self.only_norm_found_vector and normalize: | |||||
embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12) | |||||
if truncate_vocab: | |||||
for i in range(len(truncated_words_to_words)): | |||||
index_in_truncated_vocab = truncated_words_to_words[i] | |||||
truncated_words_to_words[i] = self.words_to_words[index_in_truncated_vocab] | |||||
del self.words_to_words | |||||
self.register_buffer('words_to_words', truncated_words_to_words) | |||||
self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1], | |||||
padding_idx=vocab.padding_idx, | |||||
max_norm=None, norm_type=2, scale_grad_by_freq=False, | |||||
sparse=False, _weight=embedding) | |||||
self._embed_size = self.embedding.weight.size(1) | |||||
self.requires_grad = requires_grad | |||||
self.kwargs = kwargs | |||||
@property | |||||
def weight(self): | |||||
return self.embedding.weight | |||||
def _randomly_init_embed(self, num_embedding, embedding_dim, init_embed=None): | |||||
r""" | |||||
:param int num_embedding: embedding的entry的数量 | |||||
:param int embedding_dim: embedding的维度大小 | |||||
:param callable init_embed: 初始化方法 | |||||
:return: torch.FloatTensor | |||||
""" | |||||
embed = torch.zeros(num_embedding, embedding_dim) | |||||
if init_embed is None: | |||||
nn.init.uniform_(embed, -np.sqrt(3 / embedding_dim), np.sqrt(3 / embedding_dim)) | |||||
else: | |||||
init_embed(embed) | |||||
return embed | |||||
def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>', | |||||
error='ignore', init_method=None): | |||||
r""" | |||||
从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是 | |||||
word2vec(第一行只有两个元素)还是glove格式的数据。 | |||||
:param str embed_filepath: 预训练的embedding的路径。 | |||||
:param vocab: 词表 :class:`~fastNLP.Vocabulary` 类型,读取出现在vocab中的词的embedding。 | |||||
没有出现在vocab中的词的embedding将通过找到的词的embedding的正态分布采样出来,以使得整个Embedding是同分布的。 | |||||
:param dtype: 读出的embedding的类型 | |||||
:param str padding: 词表中padding的token | |||||
:param str unknown: 词表中unknown的token | |||||
:param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。 | |||||
这里主要可能出错的地方在于词表有空行或者词表出现了维度不一致。 | |||||
:param init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。默认使用torch.nn.init.zeros_ | |||||
:return torch.tensor: shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。 | |||||
""" | |||||
assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported." | |||||
if not os.path.exists(embed_filepath): | |||||
raise FileNotFoundError("`{}` does not exist.".format(embed_filepath)) | |||||
with open(embed_filepath, 'r', encoding='utf-8') as f: | |||||
line = f.readline().strip() | |||||
parts = line.split() | |||||
start_idx = 0 | |||||
if len(parts) == 2: | |||||
dim = int(parts[1]) | |||||
start_idx += 1 | |||||
else: | |||||
dim = len(parts) - 1 | |||||
f.seek(0) | |||||
matrix = {} # index是word在vocab中的index,value是vector或None(如果在pretrain中没有找到该word) | |||||
if vocab.padding: | |||||
matrix[vocab.padding_idx] = torch.zeros(dim) | |||||
if vocab.unknown: | |||||
matrix[vocab.unknown_idx] = torch.zeros(dim) | |||||
found_count = 0 | |||||
found_unknown = False | |||||
for idx, line in enumerate(f, start_idx): | |||||
try: | |||||
parts = line.strip().split() | |||||
word = ''.join(parts[:-dim]) | |||||
nums = parts[-dim:] | |||||
# 对齐unk与pad | |||||
if word == padding and vocab.padding is not None: | |||||
word = vocab.padding | |||||
elif word == unknown and vocab.unknown is not None: | |||||
word = vocab.unknown | |||||
found_unknown = True | |||||
if word in vocab: | |||||
index = vocab.to_index(word) | |||||
if index in matrix: | |||||
warnings.warn(f"Word has more than one vector in embedding file. Set logger level to " | |||||
f"DEBUG for detail.") | |||||
logger.debug(f"Word:{word} occurs again in line:{idx}(starts from 0)") | |||||
matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim)) | |||||
if self.only_norm_found_vector: | |||||
matrix[index] = matrix[index] / np.linalg.norm(matrix[index]) | |||||
found_count += 1 | |||||
except Exception as e: | |||||
if error == 'ignore': | |||||
warnings.warn("Error occurred at the {} line.".format(idx)) | |||||
else: | |||||
logger.error("Error occurred at the {} line.".format(idx)) | |||||
raise e | |||||
logger.info("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab))) | |||||
if not self.only_use_pretrain_word: # 如果只用pretrain中的值就不要为未找到的词创建entry了 | |||||
for word, index in vocab: | |||||
if index not in matrix and not vocab._is_word_no_create_entry(word): | |||||
if found_unknown: # 如果有unkonwn,用unknown初始化 | |||||
matrix[index] = matrix[vocab.unknown_idx] | |||||
else: | |||||
matrix[index] = None | |||||
# matrix中代表是需要建立entry的词 | |||||
vectors = self._randomly_init_embed(len(matrix), dim, init_method) | |||||
if vocab.unknown is None: # 创建一个专门的unknown | |||||
unknown_idx = len(matrix) | |||||
vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous() | |||||
else: | |||||
unknown_idx = vocab.unknown_idx | |||||
self.register_buffer('words_to_words', torch.full((len(vocab), ), fill_value=unknown_idx, dtype=torch.long).long()) | |||||
index = 0 | |||||
for word, index_in_vocab in vocab: | |||||
if index_in_vocab in matrix: | |||||
vec = matrix.get(index_in_vocab) | |||||
if vec is not None: # 使用找到的vector, 如果为None说明需要训练 | |||||
vectors[index] = vec | |||||
self.words_to_words[index_in_vocab] = index | |||||
index += 1 | |||||
return vectors | |||||
def forward(self, words): | |||||
r""" | |||||
传入words的index | |||||
:param words: torch.LongTensor, [batch_size, max_len] | |||||
:return: torch.FloatTensor, [batch_size, max_len, embed_size] | |||||
""" | |||||
if hasattr(self, 'words_to_words'): | |||||
words = self.words_to_words[words] | |||||
words = self.drop_word(words) | |||||
words = self.embedding(words) | |||||
words = self.dropout(words) | |||||
return words | |||||
def save(self, folder): | |||||
""" | |||||
将embedding存储到folder下,之后可以通过使用load方法读取 | |||||
:param str folder: 会在该folder下生成三个文件, vocab.txt, static_embed_hyper.txt, static_embed_hyper.json. | |||||
其中vocab.txt可以用Vocabulary通过load读取; embedding.txt按照word2vec的方式存储,以空格的方式隔开元素, | |||||
第一行只有两个元素,剩下的行首先是word然后是各个维度的值; static_embed_hyper.json是StaticEmbedding的超参数 | |||||
:return: | |||||
""" | |||||
os.makedirs(folder, exist_ok=True) | |||||
vocab = self.get_word_vocab() | |||||
vocab_fp = os.path.join(folder, VOCAB_FILENAME) | |||||
vocab.save(vocab_fp) | |||||
kwargs = self.kwargs.copy() | |||||
kwargs['dropout'] = self.dropout_layer.p | |||||
kwargs['word_dropout'] = self.word_dropout | |||||
kwargs['requires_grad'] = self.requires_grad | |||||
kwargs['only_norm_found_vector'] = False | |||||
kwargs['only_use_pretrain_word'] = True | |||||
with open(os.path.join(folder, STATIC_HYPER_FILENAME), 'w', encoding='utf-8') as f: | |||||
json.dump(kwargs, f, indent=2) | |||||
with open(os.path.join(folder, STATIC_EMBED_FILENAME), 'w', encoding='utf-8') as f: | |||||
f.write('{}\n'.format(' '*30)) # 留白之后再来填写 | |||||
word_count = 0 | |||||
saved_word = {} | |||||
valid_word_count = 0 | |||||
for i in range(len(self.words_to_words)): | |||||
word = vocab.to_word(i) | |||||
if not vocab._is_word_no_create_entry(word): | |||||
word_count += 1 | |||||
if kwargs['lower']: | |||||
word = word.lower() | |||||
if word in saved_word: | |||||
continue | |||||
saved_word[word] = 1 | |||||
vec_i = self.words_to_words[i] | |||||
if vec_i==vocab.unknown_idx and i!=vocab.unknown_idx: | |||||
continue | |||||
vec = self.embedding.weight.data[vec_i].tolist() | |||||
vec_str = ' '.join(map(str, vec)) | |||||
f.write(f'{word} {vec_str}\n') | |||||
valid_word_count += 1 | |||||
f.seek(0) | |||||
f.write('{} {}'.format(valid_word_count, self.embedding_dim)) | |||||
logger.debug(f"StaticEmbedding has been saved to {folder}.") | |||||
@classmethod | |||||
def load(cls, folder): | |||||
""" | |||||
:param str folder: 该folder下应该有以下三个文件vocab.txt, static_embed.txt, static_hyper.json | |||||
:return: | |||||
""" | |||||
for name in [VOCAB_FILENAME, STATIC_EMBED_FILENAME, STATIC_HYPER_FILENAME]: | |||||
assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}." | |||||
vocab = Vocabulary.load(os.path.join(folder, VOCAB_FILENAME)) | |||||
with open(os.path.join(folder, STATIC_HYPER_FILENAME), 'r', encoding='utf-8') as f: | |||||
hyper = json.load(f) | |||||
logger.info(f"Load StaticEmbedding from {folder}.") | |||||
embed = cls(vocab=vocab, model_dir_or_name=os.path.join(folder, STATIC_EMBED_FILENAME), **hyper) | |||||
return embed | |||||
@@ -0,0 +1,106 @@ | |||||
r""" | |||||
.. todo:: | |||||
doc | |||||
""" | |||||
import numpy as np | |||||
from ...envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from torch import nn as nn | |||||
from ...core.vocabulary import Vocabulary | |||||
__all__ = [ | |||||
'get_embeddings', | |||||
'get_sinusoid_encoding_table' | |||||
] | |||||
def _construct_char_vocab_from_vocab(vocab: Vocabulary, min_freq: int = 1, include_word_start_end=True): | |||||
r""" | |||||
给定一个word的vocabulary生成character的vocabulary. | |||||
:param vocab: 从vocab | |||||
:param min_freq: | |||||
:param include_word_start_end: 是否需要包含特殊的<bow>和<eos> | |||||
:return: | |||||
""" | |||||
char_vocab = Vocabulary(min_freq=min_freq) | |||||
for word, index in vocab: | |||||
if not vocab._is_word_no_create_entry(word): | |||||
char_vocab.add_word_lst(list(word)) | |||||
if include_word_start_end: | |||||
char_vocab.add_word_lst(['<bow>', '<eow>']) | |||||
return char_vocab | |||||
def get_embeddings(init_embed, padding_idx=None): | |||||
r""" | |||||
根据输入的init_embed返回Embedding对象。如果输入是tuple, 则随机初始化一个nn.Embedding; 如果输入是numpy.ndarray, 则按照ndarray | |||||
的值将nn.Embedding初始化; 如果输入是torch.Tensor, 则按该值初始化nn.Embedding; 如果输入是fastNLP中的embedding将不做处理 | |||||
返回原对象。 | |||||
:param init_embed: 可以是 tuple:(num_embedings, embedding_dim), 即embedding的大小和每个词的维度;也可以传入 | |||||
nn.Embedding 对象, 此时就以传入的对象作为embedding; 传入np.ndarray也行,将使用传入的ndarray作为作为Embedding初始化; | |||||
传入torch.Tensor, 将使用传入的值作为Embedding初始化。 | |||||
:param padding_idx: 当传入tuple时,padding_idx有效 | |||||
:return nn.Embedding: embeddings | |||||
""" | |||||
if isinstance(init_embed, tuple): | |||||
res = nn.Embedding( | |||||
num_embeddings=init_embed[0], embedding_dim=init_embed[1], padding_idx=padding_idx) | |||||
nn.init.uniform_(res.weight.data, a=-np.sqrt(3 / res.weight.data.size(1)), | |||||
b=np.sqrt(3 / res.weight.data.size(1))) | |||||
elif isinstance(init_embed, nn.Module): | |||||
res = init_embed | |||||
elif isinstance(init_embed, torch.Tensor): | |||||
res = nn.Embedding.from_pretrained(init_embed, freeze=False) | |||||
elif isinstance(init_embed, np.ndarray): | |||||
init_embed = torch.tensor(init_embed, dtype=torch.float32) | |||||
res = nn.Embedding.from_pretrained(init_embed, freeze=False) | |||||
else: | |||||
raise TypeError( | |||||
'invalid init_embed type: {}'.format((type(init_embed)))) | |||||
return res | |||||
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): | |||||
""" | |||||
sinusoid的embedding,其中position的表示中,偶数维(0,2,4,...)是sin, 奇数(1,3,5...)是cos | |||||
:param int n_position: 一共多少个position | |||||
:param int d_hid: 多少维度,需要为偶数 | |||||
:param padding_idx: | |||||
:return: torch.FloatTensor, shape为n_position x d_hid | |||||
""" | |||||
def cal_angle(position, hid_idx): | |||||
return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) | |||||
def get_posi_angle_vec(position): | |||||
return [cal_angle(position, hid_j) for hid_j in range(d_hid)] | |||||
sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) | |||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i | |||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 | |||||
if padding_idx is not None: | |||||
# zero vector for padding dimension | |||||
sinusoid_table[padding_idx] = 0. | |||||
return torch.FloatTensor(sinusoid_table) | |||||
def _check_vocab_has_same_index(vocab, other_vocab): | |||||
""" | |||||
检查两个vocabulary是否含有相同的word idx | |||||
:param Vocabulary vocab: | |||||
:param Vocabulary other_vocab: | |||||
:return: | |||||
""" | |||||
if other_vocab != vocab: | |||||
for word, word_ix in vocab: | |||||
other_word_idx = other_vocab.to_index(word) | |||||
assert other_word_idx == word_ix, f"Word {word} has different index in vocabs, {word_ix} Vs. {other_word_idx}." |
@@ -20,13 +20,17 @@ def is_cur_env_distributed() -> bool: | |||||
""" | """ | ||||
单卡模式该函数一定返回 False; | 单卡模式该函数一定返回 False; | ||||
注意进程 0 在多卡的训练模式下前后的值是不一样的,例如在开启多卡的 driver 之前,在进程 0 上的该函数返回 False;但是在开启后,在进程 0 上 | 注意进程 0 在多卡的训练模式下前后的值是不一样的,例如在开启多卡的 driver 之前,在进程 0 上的该函数返回 False;但是在开启后,在进程 0 上 | ||||
的该函数返回的值是 True; | |||||
多卡模式下除了进程 0 外的其它进程返回的值一定是 True; | |||||
的该函数返回的值是 True;多卡模式下除了进程 0 外的其它进程返回的值一定是 True; | |||||
""" | """ | ||||
return FASTNLP_GLOBAL_RANK in os.environ | return FASTNLP_GLOBAL_RANK in os.environ | ||||
def get_global_rank(): | |||||
def get_global_rank()->int: | |||||
""" | |||||
获取当前进程的 global_rank 。 | |||||
:return: | |||||
""" | |||||
return int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) | return int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) | ||||
@@ -64,7 +68,7 @@ def rank_zero_call(fn: Callable): | |||||
@contextmanager | @contextmanager | ||||
def fastnlp_no_sync_context(level=2): | def fastnlp_no_sync_context(level=2): | ||||
""" | """ | ||||
用于让 fastNLP 的 barrier 以及 gather/broadcast等操作等同于只有1卡的多卡程序。如果为 1 表示 fastNLP 里的barrier 操作失效; | |||||
用于让 fastNLP 的 barrier 以及 gather/broadcast等操作等同于只有 1 卡的多卡程序。如果为 1 表示 fastNLP 里的barrier 操作失效; | |||||
如果为 2 表示 barrier 与 gather/broadcast 都失效。 | 如果为 2 表示 barrier 与 gather/broadcast 都失效。 | ||||
:param int level: 可选 [0, 1, 2] | :param int level: 可选 [0, 1, 2] | ||||
@@ -109,13 +109,9 @@ __all__ = [ | |||||
"CMRC2018BertPipe", | "CMRC2018BertPipe", | ||||
'ModelLoader', | |||||
'ModelSaver', | |||||
] | ] | ||||
from .data_bundle import DataBundle | from .data_bundle import DataBundle | ||||
from .embed_loader import EmbedLoader | from .embed_loader import EmbedLoader | ||||
from .loader import * | from .loader import * | ||||
from .model_io import ModelLoader, ModelSaver | |||||
from .pipe import * | from .pipe import * |
@@ -249,7 +249,7 @@ class DataBundle: | |||||
return self | return self | ||||
def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True, | def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True, | ||||
ignore_miss_dataset=True, progress_desc: str = '', show_progress_bar: bool = True): | |||||
ignore_miss_dataset=True, show_progress_bar: bool = True, progress_desc: str = ''): | |||||
r""" | r""" | ||||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法 | 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法 | ||||
@@ -263,8 +263,8 @@ class DataBundle: | |||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | ||||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | ||||
如果为False,则报错 | 如果为False,则报错 | ||||
:param show_progress_bar: 是否显示tqdm进度条 | |||||
:param progress_desc: 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 | |||||
:param show_progress_bar: 是否显示进度条 | |||||
:param progress_desc: 当 ``show_progress_bar`` 为 ``True`` 时,可以显示 ``progress`` 的名称。 | |||||
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | ||||
@@ -1,71 +0,0 @@ | |||||
r""" | |||||
用于载入和保存模型 | |||||
""" | |||||
__all__ = [ | |||||
"ModelLoader", | |||||
"ModelSaver" | |||||
] | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
class ModelLoader: | |||||
r""" | |||||
用于读取模型 | |||||
""" | |||||
def __init__(self): | |||||
super(ModelLoader, self).__init__() | |||||
@staticmethod | |||||
def load_pytorch(empty_model, model_path): | |||||
r""" | |||||
从 ".pkl" 文件读取 PyTorch 模型 | |||||
:param empty_model: 初始化参数的 PyTorch 模型 | |||||
:param str model_path: 模型保存的路径 | |||||
""" | |||||
empty_model.load_state_dict(torch.load(model_path)) | |||||
@staticmethod | |||||
def load_pytorch_model(model_path): | |||||
r""" | |||||
读取整个模型 | |||||
:param str model_path: 模型保存的路径 | |||||
""" | |||||
return torch.load(model_path) | |||||
class ModelSaver(object): | |||||
r""" | |||||
用于保存模型 | |||||
Example:: | |||||
saver = ModelSaver("./save/model_ckpt_100.pkl") | |||||
saver.save_pytorch(model) | |||||
""" | |||||
def __init__(self, save_path): | |||||
r""" | |||||
:param save_path: 模型保存的路径 | |||||
""" | |||||
self.save_path = save_path | |||||
def save_pytorch(self, model, param_only=True): | |||||
r""" | |||||
把 PyTorch 模型存入 ".pkl" 文件 | |||||
:param model: PyTorch 模型 | |||||
:param bool param_only: 是否只保存模型的参数(否则保存整个模型) | |||||
""" | |||||
if param_only is True: | |||||
torch.save(model.state_dict(), self.save_path) | |||||
else: | |||||
torch.save(model, self.save_path) |
@@ -202,12 +202,12 @@ def jittor2torch(batch: Any, device: str = None, no_gradient: bool = None) -> An | |||||
.. note:: | .. note:: | ||||
注意,由于 **pytorch** 和 **jittor** 之间的差异,从 :class:`jittor.Var` 转换 | |||||
至 :class:`torch.Tensor` 的过程中无法保留原张量的梯度。 | |||||
注意,由于 **pytorch** 和 **jittor** 之间的差异,从 :class:`jittor.Var` 转换至 | |||||
:class:`torch.Tensor` 的过程中无法保留原张量的梯度。 | |||||
:param batch: 包含 :class:`jittor.Var` 类型的数据集合; | :param batch: 包含 :class:`jittor.Var` 类型的数据集合; | ||||
:param device: 是否将转换后的张量迁移到特定设备上。为 ``None``时,和输入保持一致; | :param device: 是否将转换后的张量迁移到特定设备上。为 ``None``时,和输入保持一致; | ||||
:param no_gradient: 是否保留原张量的梯度,在这个函数中该参数无效。 | |||||
:param no_gradient: 是否保留原张量的梯度,在这个函数中该参数无效; | |||||
:return: 转换后的数据; | :return: 转换后的数据; | ||||
""" | """ | ||||
@@ -0,0 +1,5 @@ | |||||
__all__ = [ | |||||
"LSTM", | |||||
] | |||||
from .lstm import LSTM |
@@ -0,0 +1,82 @@ | |||||
r"""undocumented | |||||
轻量封装的 Pytorch LSTM 模块. | |||||
可在 forward 时传入序列的长度, 自动对padding做合适的处理. | |||||
""" | |||||
__all__ = [ | |||||
"LSTM" | |||||
] | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.utils.rnn as rnn | |||||
class LSTM(nn.Module): | |||||
r""" | |||||
LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化 | |||||
为1; 且可以应对DataParallel中LSTM的使用问题。 | |||||
""" | |||||
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, | |||||
bidirectional=False, bias=True): | |||||
r""" | |||||
:param input_size: 输入 `x` 的特征维度 | |||||
:param hidden_size: 隐状态 `h` 的特征维度. 如果bidirectional为True,则输出的维度会是hidde_size*2 | |||||
:param num_layers: rnn的层数. Default: 1 | |||||
:param dropout: 层间dropout概率. Default: 0 | |||||
:param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False`` | |||||
:param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 | |||||
:(batch, seq, feature). Default: ``False`` | |||||
:param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` | |||||
""" | |||||
super(LSTM, self).__init__() | |||||
self.batch_first = batch_first | |||||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, | |||||
dropout=dropout, bidirectional=bidirectional) | |||||
self.init_param() | |||||
def init_param(self): | |||||
for name, param in self.named_parameters(): | |||||
if 'bias' in name: | |||||
# based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871 | |||||
param.data.fill_(0) | |||||
n = param.size(0) | |||||
start, end = n // 4, n // 2 | |||||
param.data[start:end].fill_(1) | |||||
else: | |||||
nn.init.xavier_uniform_(param) | |||||
def forward(self, x, seq_len=None, h0=None, c0=None): | |||||
r""" | |||||
:param x: [batch, seq_len, input_size] 输入序列 | |||||
:param seq_len: [batch, ] 序列长度, 若为 ``None``, 所有输入看做一样长. Default: ``None`` | |||||
:param h0: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全0向量. Default: ``None`` | |||||
:param c0: [batch, hidden_size] 初始Cell状态, 若为 ``None`` , 设为全0向量. Default: ``None`` | |||||
:return (output, (ht, ct)): output: [batch, seq_len, hidden_size*num_direction] 输出序列 | |||||
和 ht,ct: [num_layers*num_direction, batch, hidden_size] 最后时刻隐状态. | |||||
""" | |||||
batch_size, max_len, _ = x.size() | |||||
if h0 is not None and c0 is not None: | |||||
hx = (h0, c0) | |||||
else: | |||||
hx = None | |||||
if seq_len is not None and not isinstance(x, rnn.PackedSequence): | |||||
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) | |||||
if self.batch_first: | |||||
x = x[sort_idx] | |||||
else: | |||||
x = x[:, sort_idx] | |||||
x = rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=self.batch_first) | |||||
output, hx = self.lstm(x, hx) # -> [N,L,C] | |||||
output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len) | |||||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | |||||
if self.batch_first: | |||||
output = output[unsort_idx] | |||||
else: | |||||
output = output[:, unsort_idx] | |||||
hx = hx[0][:, unsort_idx], hx[1][:, unsort_idx] | |||||
else: | |||||
output, hx = self.lstm(x, hx) | |||||
return output, hx |
@@ -314,7 +314,7 @@ class PretrainedConfig: | |||||
# TPU arguments | # TPU arguments | ||||
if kwargs.pop("xla_device", None) is not None: | if kwargs.pop("xla_device", None) is not None: | ||||
logger.warning( | |||||
logger.rank_zero_warning( | |||||
"The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can " | "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can " | ||||
"safely remove it from your `config.json` file." | "safely remove it from your `config.json` file." | ||||
) | ) | ||||
@@ -474,7 +474,7 @@ class PretrainedConfig: | |||||
""" | """ | ||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) | ||||
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: | if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: | ||||
logger.warn( | |||||
logger.rank_zero_warning( | |||||
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " | f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " | ||||
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." | f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." | ||||
) | ) | ||||
@@ -564,9 +564,9 @@ class PretrainedConfig: | |||||
raise EnvironmentError(msg) | raise EnvironmentError(msg) | ||||
if resolved_config_file == config_file: | if resolved_config_file == config_file: | ||||
logger.info(f"loading configuration file {config_file}") | |||||
logger.debug(f"loading configuration file {config_file}") | |||||
else: | else: | ||||
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}") | |||||
logger.debug(f"loading configuration file {config_file} from cache at {resolved_config_file}") | |||||
return config_dict, kwargs | return config_dict, kwargs | ||||
@@ -603,7 +603,7 @@ class PretrainedConfig: | |||||
for key in to_remove: | for key in to_remove: | ||||
kwargs.pop(key, None) | kwargs.pop(key, None) | ||||
logger.info(f"Model config {config}") | |||||
logger.debug(f"Model config {config}") | |||||
if return_unused_kwargs: | if return_unused_kwargs: | ||||
return config, kwargs | return config, kwargs | ||||
else: | else: | ||||
@@ -82,6 +82,52 @@ def filelock(path): | |||||
except: | except: | ||||
pass | pass | ||||
class HfFolder: | |||||
""" | |||||
hugging_face.HfFolder | |||||
version = 0.5.1 | |||||
""" | |||||
path_token = os.path.expanduser("~/.huggingface/token") | |||||
@classmethod | |||||
def save_token(cls, token): | |||||
""" | |||||
Save token, creating folder as needed. | |||||
Args: | |||||
token (`str`): | |||||
The token to save to the [`HfFolder`] | |||||
""" | |||||
os.makedirs(os.path.dirname(cls.path_token), exist_ok=True) | |||||
with open(cls.path_token, "w+") as f: | |||||
f.write(token) | |||||
@classmethod | |||||
def get_token(cls): | |||||
""" | |||||
Retrieves the token | |||||
Returns: | |||||
`str` or `None`: The token, `None` if it doesn't exist. | |||||
""" | |||||
try: | |||||
with open(cls.path_token, "r") as f: | |||||
return f.read() | |||||
except FileNotFoundError: | |||||
pass | |||||
@classmethod | |||||
def delete_token(cls): | |||||
""" | |||||
Deletes the token from storage. Does not fail if token does not exist. | |||||
""" | |||||
try: | |||||
os.remove(cls.path_token) | |||||
except FileNotFoundError: | |||||
pass | |||||
def is_offline_mode(): | def is_offline_mode(): | ||||
return _is_offline_mode | return _is_offline_mode | ||||
@@ -629,11 +675,10 @@ def get_from_cache( | |||||
if isinstance(use_auth_token, str): | if isinstance(use_auth_token, str): | ||||
headers["authorization"] = f"Bearer {use_auth_token}" | headers["authorization"] = f"Bearer {use_auth_token}" | ||||
elif use_auth_token: | elif use_auth_token: | ||||
raise RuntimeError("`use_auth_token=True` is not supported in FastNLP now") | |||||
# token = HfFolder.get_token() | |||||
# if token is None: | |||||
# raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.") | |||||
# headers["authorization"] = f"Bearer {token}" | |||||
token = HfFolder.get_token() | |||||
if token is None: | |||||
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.") | |||||
headers["authorization"] = f"Bearer {token}" | |||||
url_to_download = url | url_to_download = url | ||||
etag = None | etag = None | ||||
@@ -791,13 +836,7 @@ def get_list_of_files( | |||||
if isinstance(use_auth_token, str): | if isinstance(use_auth_token, str): | ||||
token = use_auth_token | token = use_auth_token | ||||
elif use_auth_token is True: | elif use_auth_token is True: | ||||
# token = HfFolder.get_token() | |||||
path_token = os.path.expanduser("~/.huggingface/token") | |||||
try: | |||||
with open(path_token, "r") as f: | |||||
token = f.read() | |||||
except FileNotFoundError: | |||||
token = None | |||||
token = HfFolder.get_token() | |||||
else: | else: | ||||
token = None | token = None | ||||
# model_info = HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).model_info( | # model_info = HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).model_info( | ||||
@@ -122,7 +122,7 @@ def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_leng | |||||
stopping_max_length = stopping_criteria.max_length | stopping_max_length = stopping_criteria.max_length | ||||
new_stopping_criteria = deepcopy(stopping_criteria) | new_stopping_criteria = deepcopy(stopping_criteria) | ||||
if stopping_max_length is not None and stopping_max_length != max_length: | if stopping_max_length is not None and stopping_max_length != max_length: | ||||
logger.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning) | |||||
logger.rank_zero_warning("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning) | |||||
elif stopping_max_length is None: | elif stopping_max_length is None: | ||||
new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length)) | new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length)) | ||||
return new_stopping_criteria | return new_stopping_criteria |
@@ -429,7 +429,7 @@ class GenerationMixin: | |||||
def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int: | def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int: | ||||
if pad_token_id is None and eos_token_id is not None: | if pad_token_id is None and eos_token_id is not None: | ||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") | |||||
logger.rank_zero_warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") | |||||
pad_token_id = eos_token_id | pad_token_id = eos_token_id | ||||
return pad_token_id | return pad_token_id | ||||
@@ -912,7 +912,7 @@ class GenerationMixin: | |||||
# special case if pad_token_id is not defined | # special case if pad_token_id is not defined | ||||
if pad_token_id is None and eos_token_id is not None: | if pad_token_id is None and eos_token_id is not None: | ||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") | |||||
logger.rank_zero_warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") | |||||
pad_token_id = eos_token_id | pad_token_id = eos_token_id | ||||
# Storing encoder_input_ids for logits_processor that could use them | # Storing encoder_input_ids for logits_processor that could use them | ||||
@@ -352,7 +352,7 @@ class ModuleUtilsMixin: | |||||
if token_inputs: | if token_inputs: | ||||
return sum([token_input.numel() for token_input in token_inputs]) | return sum([token_input.numel() for token_input in token_inputs]) | ||||
else: | else: | ||||
logger.warn( | |||||
logger.rank_zero_warning( | |||||
"Could not estimate the number of tokens of the input, floating-point operations will not be computed" | "Could not estimate the number of tokens of the input, floating-point operations will not be computed" | ||||
) | ) | ||||
return 0 | return 0 | ||||
@@ -646,7 +646,7 @@ class PreTrainedModel(Module, ModuleUtilsMixin, GenerationMixin): | |||||
# tie weights recursively | # tie weights recursively | ||||
tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights) | tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights) | ||||
if len(uninitialized_encoder_weights) > 0: | if len(uninitialized_encoder_weights) > 0: | ||||
logger.warning( | |||||
logger.rank_zero_warning( | |||||
f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}" | f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}" | ||||
) | ) | ||||
@@ -1260,9 +1260,9 @@ class PreTrainedModel(Module, ModuleUtilsMixin, GenerationMixin): | |||||
raise EnvironmentError(msg) | raise EnvironmentError(msg) | ||||
if resolved_archive_file == archive_file: | if resolved_archive_file == archive_file: | ||||
logger.info(f"loading weights file {archive_file}") | |||||
logger.debug(f"loading weights file {archive_file}") | |||||
else: | else: | ||||
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}") | |||||
logger.debug(f"loading weights file {archive_file} from cache at {resolved_archive_file}") | |||||
else: | else: | ||||
resolved_archive_file = None | resolved_archive_file = None | ||||
@@ -1486,7 +1486,7 @@ class PreTrainedModel(Module, ModuleUtilsMixin, GenerationMixin): | |||||
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") | raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") | ||||
if len(unexpected_keys) > 0: | if len(unexpected_keys) > 0: | ||||
logger.warning( | |||||
logger.rank_zero_warning( | |||||
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " | f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " | ||||
f"initializing {model.__class__.__name__}: {unexpected_keys}\n" | f"initializing {model.__class__.__name__}: {unexpected_keys}\n" | ||||
f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " | f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " | ||||
@@ -15,7 +15,7 @@ | |||||
""" Auto Tokenizer class. """ | """ Auto Tokenizer class. """ | ||||
from collections import OrderedDict | from collections import OrderedDict | ||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union | |||||
from typing import TYPE_CHECKING, Optional, Tuple | |||||
from ...file_utils import ( | from ...file_utils import ( | ||||
is_sentencepiece_available, | is_sentencepiece_available, | ||||
@@ -171,7 +171,7 @@ class BartConfig(PretrainedConfig): | |||||
# ensure backward compatibility for BART CNN models | # ensure backward compatibility for BART CNN models | ||||
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False): | if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False): | ||||
self.forced_bos_token_id = self.bos_token_id | self.forced_bos_token_id = self.bos_token_id | ||||
logger.warn( | |||||
logger.rank_zero_warning( | |||||
f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions." | f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions." | ||||
"The config can simply be saved and uploaded again to be fixed." | "The config can simply be saved and uploaded again to be fixed." | ||||
) | ) |
@@ -1700,9 +1700,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): | |||||
continue | continue | ||||
if file_path == resolved_vocab_files[file_id]: | if file_path == resolved_vocab_files[file_id]: | ||||
logger.info(f"loading file {file_path}") | |||||
logger.debug(f"loading file {file_path}") | |||||
else: | else: | ||||
logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}") | |||||
logger.debug(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}") | |||||
return cls._from_pretrained( | return cls._from_pretrained( | ||||
resolved_vocab_files, | resolved_vocab_files, | ||||
@@ -74,7 +74,7 @@ def model_and_optimizers(request): | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("version", [0, 1]) | @pytest.mark.parametrize("version", [0, 1]) | ||||
@pytest.mark.parametrize("only_state_dict", [True, False]) | @pytest.mark.parametrize("only_state_dict", [True, False]) | ||||
@magic_argv_env_context(timeout=100) | @magic_argv_env_context(timeout=100) | ||||
@@ -121,7 +121,7 @@ def test_model_checkpoint_callback_1( | |||||
# 检查生成保存模型文件的数量是不是正确的; | # 检查生成保存模型文件的数量是不是正确的; | ||||
if version == 0: | if version == 0: | ||||
if driver == "torch": | |||||
if not isinstance(device, list): | |||||
assert "model-epoch_10" in all_saved_model_paths | assert "model-epoch_10" in all_saved_model_paths | ||||
assert "model-epoch_4-batch_123" in all_saved_model_paths | assert "model-epoch_4-batch_123" in all_saved_model_paths | ||||
@@ -144,7 +144,7 @@ def test_model_checkpoint_callback_1( | |||||
pattern = re.compile("model-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") | pattern = re.compile("model-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") | ||||
if driver == "torch": | |||||
if not isinstance(device, list): | |||||
assert "model-epoch_9" in all_saved_model_paths | assert "model-epoch_9" in all_saved_model_paths | ||||
assert "model-last" in all_saved_model_paths | assert "model-last" in all_saved_model_paths | ||||
aLL_topk_folders = [] | aLL_topk_folders = [] | ||||
@@ -206,7 +206,7 @@ def test_model_checkpoint_callback_1( | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("only_state_dict", [True]) | @pytest.mark.parametrize("only_state_dict", [True]) | ||||
@magic_argv_env_context(timeout=100) | @magic_argv_env_context(timeout=100) | ||||
def test_model_checkpoint_callback_2( | def test_model_checkpoint_callback_2( | ||||
@@ -259,7 +259,7 @@ def test_model_checkpoint_callback_2( | |||||
# 检查生成保存模型文件的数量是不是正确的; | # 检查生成保存模型文件的数量是不是正确的; | ||||
all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | ||||
if driver == "torch": | |||||
if not isinstance(device, list): | |||||
assert "model-epoch_4-batch_100-exception_NotImplementedError" in all_saved_model_paths | assert "model-epoch_4-batch_100-exception_NotImplementedError" in all_saved_model_paths | ||||
exception_model_path = all_saved_model_paths["model-epoch_4-batch_100-exception_NotImplementedError"] | exception_model_path = all_saved_model_paths["model-epoch_4-batch_100-exception_NotImplementedError"] | ||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | ||||
@@ -299,7 +299,7 @@ def test_model_checkpoint_callback_2( | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("version", [0, 1]) | @pytest.mark.parametrize("version", [0, 1]) | ||||
@pytest.mark.parametrize("only_state_dict", [True, False]) | @pytest.mark.parametrize("only_state_dict", [True, False]) | ||||
@magic_argv_env_context(timeout=100) | @magic_argv_env_context(timeout=100) | ||||
@@ -347,7 +347,7 @@ def test_trainer_checkpoint_callback_1( | |||||
# 检查生成保存模型文件的数量是不是正确的; | # 检查生成保存模型文件的数量是不是正确的; | ||||
if version == 0: | if version == 0: | ||||
if driver == "torch": | |||||
if not isinstance(device, list): | |||||
assert "trainer-epoch_7" in all_saved_model_paths | assert "trainer-epoch_7" in all_saved_model_paths | ||||
assert "trainer-epoch_4-batch_123" in all_saved_model_paths | assert "trainer-epoch_4-batch_123" in all_saved_model_paths | ||||
@@ -371,7 +371,7 @@ def test_trainer_checkpoint_callback_1( | |||||
pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") | pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") | ||||
# all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | # all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | ||||
if driver == "torch": | |||||
if not isinstance(device, list): | |||||
assert "trainer-last" in all_saved_model_paths | assert "trainer-last" in all_saved_model_paths | ||||
aLL_topk_folders = [] | aLL_topk_folders = [] | ||||
for each_folder_name in all_saved_model_paths: | for each_folder_name in all_saved_model_paths: | ||||
@@ -417,7 +417,7 @@ def test_trainer_checkpoint_callback_1( | |||||
n_epochs=13, | n_epochs=13, | ||||
output_from_new_proc="all" | output_from_new_proc="all" | ||||
) | ) | ||||
trainer.load(folder, only_state_dict=only_state_dict) | |||||
trainer.load_checkpoint(folder, only_state_dict=only_state_dict) | |||||
trainer.run() | trainer.run() | ||||
trainer.driver.barrier() | trainer.driver.barrier() | ||||
@@ -489,7 +489,7 @@ def test_load_state(model_and_optimizers): | |||||
callbacks=callbacks, | callbacks=callbacks, | ||||
output_from_new_proc="all" | output_from_new_proc="all" | ||||
) | ) | ||||
trainer.load(folder=epoch_2_path) | |||||
trainer.load_checkpoint(folder=epoch_2_path) | |||||
with Capturing() as output: | with Capturing() as output: | ||||
trainer.run(num_eval_sanity_batch=0, num_train_batch_per_epoch=2) | trainer.run(num_eval_sanity_batch=0, num_train_batch_per_epoch=2) | ||||
@@ -503,7 +503,7 @@ def test_load_state(model_and_optimizers): | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
# 通过自己编写 model_save_fn 和 model_load_fn 来测试 huggingface 的 transformers 的模型的保存和加载; | # 通过自己编写 model_save_fn 和 model_load_fn 来测试 huggingface 的 transformers 的模型的保存和加载; | ||||
@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("version", [0, 1]) | @pytest.mark.parametrize("version", [0, 1]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.skip("Skip transformers test for now.") | @pytest.mark.skip("Skip transformers test for now.") | ||||
@@ -675,7 +675,7 @@ def test_trainer_checkpoint_callback_2( | |||||
# 检查生成保存模型文件的数量是不是正确的; | # 检查生成保存模型文件的数量是不是正确的; | ||||
if version == 0: | if version == 0: | ||||
if driver == "torch": | |||||
if not isinstance(device, list): | |||||
assert "trainer-epoch_1-batch_200" in all_saved_model_paths | assert "trainer-epoch_1-batch_200" in all_saved_model_paths | ||||
epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_200"] | epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_200"] | ||||
@@ -695,7 +695,7 @@ def test_trainer_checkpoint_callback_2( | |||||
pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") | pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") | ||||
# all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | # all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | ||||
if driver == "torch": | |||||
if not isinstance(device, list): | |||||
assert "trainer-last" in all_saved_model_paths | assert "trainer-last" in all_saved_model_paths | ||||
aLL_topk_folders = [] | aLL_topk_folders = [] | ||||
for each_folder_name in all_saved_model_paths: | for each_folder_name in all_saved_model_paths: | ||||
@@ -740,7 +740,7 @@ def test_trainer_checkpoint_callback_2( | |||||
output_mapping=bert_output_mapping, | output_mapping=bert_output_mapping, | ||||
metrics={"acc": acc}, | metrics={"acc": acc}, | ||||
) | ) | ||||
trainer.load(folder, model_load_fn=model_load_fn) | |||||
trainer.load_checkpoint(folder, model_load_fn=model_load_fn) | |||||
trainer.run() | trainer.run() | ||||
trainer.driver.barrier() | trainer.driver.barrier() | ||||
@@ -72,7 +72,7 @@ def model_and_optimizers(request): | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch_ddp", [4, 5]), ("torch", 1), ("torch", "cpu")]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", [4, 5]), ("torch", 1), ("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("save_folder", ['save_models', None]) | @pytest.mark.parametrize("save_folder", ['save_models', None]) | ||||
@pytest.mark.parametrize("only_state_dict", [True, False]) | @pytest.mark.parametrize("only_state_dict", [True, False]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@@ -98,7 +98,7 @@ def model_and_optimizers(request): | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("version", [0, 1]) | @pytest.mark.parametrize("version", [0, 1]) | ||||
@pytest.mark.parametrize("only_state_dict", [True, False]) | @pytest.mark.parametrize("only_state_dict", [True, False]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@@ -183,7 +183,7 @@ def test_model_more_evaluate_callback_1( | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("version", [0, 1]) | @pytest.mark.parametrize("version", [0, 1]) | ||||
@pytest.mark.parametrize("only_state_dict", [True, False]) | @pytest.mark.parametrize("only_state_dict", [True, False]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@@ -256,7 +256,7 @@ def test_trainer_checkpoint_callback_1( | |||||
evaluate_fn='train_step' | evaluate_fn='train_step' | ||||
) | ) | ||||
folder = path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).joinpath(folder) | folder = path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).joinpath(folder) | ||||
trainer.load(folder, only_state_dict=only_state_dict) | |||||
trainer.load_checkpoint(folder, only_state_dict=only_state_dict) | |||||
trainer.run() | trainer.run() | ||||
trainer.driver.barrier() | trainer.driver.barrier() | ||||
@@ -85,7 +85,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
): | ): | ||||
trainer = Trainer( | trainer = Trainer( | ||||
model=model, | model=model, | ||||
driver="torch_ddp", | |||||
driver="torch", | |||||
device=None, | device=None, | ||||
optimizers=optimizers, | optimizers=optimizers, | ||||
train_dataloader=train_dataloader, | train_dataloader=train_dataloader, | ||||
@@ -73,7 +73,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
): | ): | ||||
trainer = Trainer( | trainer = Trainer( | ||||
model=model, | model=model, | ||||
driver="torch_ddp", | |||||
driver="torch", | |||||
device=None, | device=None, | ||||
optimizers=optimizers, | optimizers=optimizers, | ||||
train_dataloader=train_dataloader, | train_dataloader=train_dataloader, | ||||
@@ -318,7 +318,7 @@ def test_torch_distributed_launch_2(version): | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch_ddp", [0, 1])]) | |||||
@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch", [0, 1])]) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_torch_wo_auto_param_call( | def test_torch_wo_auto_param_call( | ||||
driver, | driver, | ||||
@@ -160,7 +160,7 @@ class TestSetDistReproDataloader: | |||||
""" | """ | ||||
传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 | 传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 | ||||
此时对应 driver.load 中的情况 | |||||
此时对应 driver.load_checkpoint 中的情况 | |||||
""" | """ | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@@ -626,9 +626,9 @@ class TestSaveLoad: | |||||
sampler_states = dataloader.batch_sampler.state_dict() | sampler_states = dataloader.batch_sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
if only_state_dict: | if only_state_dict: | ||||
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
self.driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
else: | else: | ||||
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
self.driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
# 加载 | # 加载 | ||||
# 更改 batch_size | # 更改 batch_size | ||||
dataloader = DataLoader( | dataloader = DataLoader( | ||||
@@ -644,7 +644,7 @@ class TestSaveLoad: | |||||
rank=self.driver2.global_rank, | rank=self.driver2.global_rank, | ||||
pad=True | pad=True | ||||
) | ) | ||||
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
load_states = self.driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
# 1. 检查 optimizer 的状态 | # 1. 检查 optimizer 的状态 | ||||
# TODO optimizer 的 state_dict 总是为空 | # TODO optimizer 的 state_dict 总是为空 | ||||
@@ -736,9 +736,9 @@ class TestSaveLoad: | |||||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | sampler_states = dataloader.batch_sampler.sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
if only_state_dict: | if only_state_dict: | ||||
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
self.driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
else: | else: | ||||
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
self.driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
# 加载 | # 加载 | ||||
# 更改 batch_size | # 更改 batch_size | ||||
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) | batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) | ||||
@@ -752,7 +752,7 @@ class TestSaveLoad: | |||||
self.dataset, | self.dataset, | ||||
batch_sampler=batch_sampler | batch_sampler=batch_sampler | ||||
) | ) | ||||
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
load_states = self.driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
# 1. 检查 optimizer 的状态 | # 1. 检查 optimizer 的状态 | ||||
@@ -615,16 +615,16 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
sampler_states = dataloader.batch_sampler.state_dict() | sampler_states = dataloader.batch_sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
if only_state_dict: | if only_state_dict: | ||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
else: | else: | ||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
# 加载 | # 加载 | ||||
# 更改 batch_size | # 更改 batch_size | ||||
dataloader = DataLoader( | dataloader = DataLoader( | ||||
dataset=dataset, | dataset=dataset, | ||||
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False) | batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False) | ||||
) | ) | ||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
# 1. 检查 optimizer 的状态 | # 1. 检查 optimizer 的状态 | ||||
# TODO optimizer 的 state_dict 总是为空 | # TODO optimizer 的 state_dict 总是为空 | ||||
@@ -697,9 +697,9 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | sampler_states = dataloader.batch_sampler.sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
if only_state_dict: | if only_state_dict: | ||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
else: | else: | ||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
# 加载 | # 加载 | ||||
# 更改 batch_size | # 更改 batch_size | ||||
@@ -709,7 +709,7 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||||
dataset, | dataset, | ||||
batch_sampler=batch_sampler | batch_sampler=batch_sampler | ||||
) | ) | ||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
# 1. 检查 optimizer 的状态 | # 1. 检查 optimizer 的状态 | ||||
@@ -186,7 +186,7 @@ class TestSetDistReproDataloader: | |||||
""" | """ | ||||
传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 | 传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 | ||||
此时对应 driver.load 中的情况 | |||||
此时对应 driver.load_checkpoint 中的情况 | |||||
""" | """ | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@@ -648,7 +648,7 @@ class TestSaveLoad: | |||||
# 保存状态 | # 保存状态 | ||||
sampler_states = dataloader.batch_sampler.state_dict() | sampler_states = dataloader.batch_sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
# 加载 | # 加载 | ||||
# 更改 batch_size | # 更改 batch_size | ||||
dataloader = dataloader_with_bucketedbatchsampler( | dataloader = dataloader_with_bucketedbatchsampler( | ||||
@@ -663,7 +663,7 @@ class TestSaveLoad: | |||||
rank=driver2.global_rank, | rank=driver2.global_rank, | ||||
pad=True | pad=True | ||||
) | ) | ||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
# 1. 检查 optimizer 的状态 | # 1. 检查 optimizer 的状态 | ||||
# TODO optimizer 的 state_dict 总是为空 | # TODO optimizer 的 state_dict 总是为空 | ||||
@@ -754,9 +754,9 @@ class TestSaveLoad: | |||||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | sampler_states = dataloader.batch_sampler.sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
if only_state_dict: | if only_state_dict: | ||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
else: | else: | ||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) | |||||
# 加载 | # 加载 | ||||
# 更改 batch_size | # 更改 batch_size | ||||
dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False) | dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False) | ||||
@@ -765,7 +765,7 @@ class TestSaveLoad: | |||||
rank=driver2.global_rank, | rank=driver2.global_rank, | ||||
pad=True | pad=True | ||||
) | ) | ||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
# 1. 检查 optimizer 的状态 | # 1. 检查 optimizer 的状态 | ||||
@@ -37,28 +37,6 @@ def test_get_single_device(driver, device): | |||||
driver = initialize_torch_driver(driver, device, model) | driver = initialize_torch_driver(driver, device, model) | ||||
assert isinstance(driver, TorchSingleDriver) | assert isinstance(driver, TorchSingleDriver) | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize( | |||||
"device", | |||||
[0, 1] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["torch_ddp"] | |||||
) | |||||
@magic_argv_env_context | |||||
def test_get_ddp_2(driver, device): | |||||
""" | |||||
测试 ddp 多卡的初始化情况,但传入了单个 gpu | |||||
""" | |||||
model = TorchNormalModel_Classification_1(64, 10) | |||||
driver = initialize_torch_driver(driver, device, model) | |||||
assert isinstance(driver, TorchDDPDriver) | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
@@ -66,7 +44,7 @@ def test_get_ddp_2(driver, device): | |||||
) | ) | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"driver", | "driver", | ||||
["torch", "torch_ddp"] | |||||
["torch"] | |||||
) | ) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_get_ddp(driver, device): | def test_get_ddp(driver, device): | ||||
@@ -79,21 +57,6 @@ def test_get_ddp(driver, device): | |||||
assert isinstance(driver, TorchDDPDriver) | assert isinstance(driver, TorchDDPDriver) | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize( | |||||
("driver", "device"), | |||||
[("torch_ddp", "cpu")] | |||||
) | |||||
def test_get_ddp_cpu(driver, device): | |||||
""" | |||||
测试试图在 cpu 上初始化分布式训练的情况 | |||||
""" | |||||
model = TorchNormalModel_Classification_1(64, 10) | |||||
with pytest.raises(ValueError): | |||||
driver = initialize_torch_driver(driver, device, model) | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
@@ -101,7 +64,7 @@ def test_get_ddp_cpu(driver, device): | |||||
) | ) | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"driver", | "driver", | ||||
["torch", "torch_ddp"] | |||||
["torch"] | |||||
) | ) | ||||
def test_device_out_of_range(driver, device): | def test_device_out_of_range(driver, device): | ||||
""" | """ | ||||
@@ -595,12 +595,12 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
sampler_states = dataloader.batch_sampler.state_dict() | sampler_states = dataloader.batch_sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
# 加载 | # 加载 | ||||
# 更改 batch_size | # 更改 batch_size | ||||
dataloader = dataloader_with_randombatchsampler(dataset, 2, True, False) | dataloader = dataloader_with_randombatchsampler(dataset, 2, True, False) | ||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
# 1. 检查 optimizer 的状态 | # 1. 检查 optimizer 的状态 | ||||
# TODO optimizer 的 state_dict 总是为空 | # TODO optimizer 的 state_dict 总是为空 | ||||
@@ -664,12 +664,12 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | sampler_states = dataloader.batch_sampler.sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
# 加载 | # 加载 | ||||
# 更改 batch_size | # 更改 batch_size | ||||
dataloader = dataloader_with_randomsampler(dataset, 2, True, False) | dataloader = dataloader_with_randomsampler(dataset, 2, True, False) | ||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
# 1. 检查 optimizer 的状态 | # 1. 检查 optimizer 的状态 | ||||