@@ -49,12 +49,17 @@ class Callback: | |||
def on_after_trainer_initialized(self, trainer, driver): | |||
r""" | |||
在 `Trainer` 初始化后会被触发; | |||
:param trainer: ``Trainer`` 实例; | |||
:param driver: ``Trainer`` 中的 ``driver`` 实例; | |||
""" | |||
pass | |||
def on_sanity_check_begin(self, trainer): | |||
r""" | |||
在 '预跑'检测 开始前会被触发; | |||
:param trainer: ``Trainer`` 实例; | |||
""" | |||
pass | |||
@@ -62,9 +67,8 @@ class Callback: | |||
r""" | |||
在 '预跑'检测 开始后会被触发; | |||
:param trainer: | |||
:param sanity_check_res: 预跑的 evaluate 结果 | |||
:return: | |||
:param trainer: ``Trainer`` 实例; | |||
:param sanity_check_res: 预跑得到的评测结果,关于对于 **预跑** 的解释,请见 :meth:`~fastNLP.core.controllers.trainer.Trainer.run`; | |||
""" | |||
pass | |||
@@ -72,8 +76,7 @@ class Callback: | |||
r""" | |||
在训练开始前会被触发; | |||
:param trainer: | |||
:return: | |||
:param trainer: ``Trainer`` 实例; | |||
""" | |||
pass | |||
@@ -81,8 +84,7 @@ class Callback: | |||
r""" | |||
在训练完成后会被触发; | |||
:param trainer: | |||
:return: | |||
:param trainer: ``Trainer`` 实例; | |||
""" | |||
pass | |||
@@ -90,8 +92,7 @@ class Callback: | |||
r""" | |||
在训练过程中的每一个 epoch 开始前会被触发; | |||
:param trainer: | |||
:return: | |||
:param trainer: ``Trainer`` 实例; | |||
""" | |||
pass | |||
@@ -99,8 +100,7 @@ class Callback: | |||
r""" | |||
在训练过程中的每一个 epoch 完成后会被触发;此时 trainer.cur_epoch_idx 已经完成加 1 操作。 | |||
:param trainer: | |||
:return: | |||
:param trainer: ``Trainer`` 实例; | |||
""" | |||
pass | |||
@@ -108,8 +108,7 @@ class Callback: | |||
r""" | |||
在训练过程中准备取出下一个 batch 的数据时触发 | |||
:param trainer: | |||
:return: | |||
:param trainer: ``Trainer`` 实例; | |||
""" | |||
pass | |||
@@ -117,179 +116,161 @@ class Callback: | |||
r""" | |||
在训练过程中拿到当前的 batch 数据后会被触发; | |||
:param trainer: | |||
:return: | |||
:param trainer: ``Trainer`` 实例; | |||
""" | |||
pass | |||
def on_train_batch_begin(self, trainer, batch, indices): | |||
r""" | |||
在取得数据,执行完 input_mapping (如果 Trainer 传有该参数),并且移动 batch 中的 tensor 到了指定设备。 | |||
其中 batch 中的数据格式要么是 Dataloader 返回的每个 batch 的格式;要么是 input_mapping 之后的内容。 | |||
如果 batch 是 dict 类型,直接增删其中的 key 或 修改其中的 value 会影响到输入到 model 的中的 batch 数据。 | |||
在取得数据,执行完 ``input_mapping`` (如果 ``Trainer`` 传有该参数),并且移动 ``batch`` 中的 ``tensor`` 到了指定设备。 | |||
其中 ``batch`` 中的数据格式要么是 ``Dataloader`` 返回的每个 ``batch`` 的格式;要么是 ``input_mapping`` 之后的内容。 | |||
如果 ``batch`` 是 ``dict`` 类型,直接增删其中的 ``key`` 或 修改其中的 ``value`` 会影响到输入到 ``model`` 的中的 ``batch`` 数据。 | |||
:param trainer: `fastNLP.Trainer` | |||
:param batch: batch 的数据,已经经过 input_mapping (如果有) 以及 移动到指定设备 。 | |||
:param list[int] indices: 当前的 batch 是 dataset 中的哪些数据。仅在 DataLoader 支持得到当前 batch index 的时候有值, | |||
:param trainer: ``Trainer`` 实例; | |||
:param batch: batch 的数据,已经经过 ``input_mapping`` (如果有) 以及移动到指定设备 。 | |||
:param list[int] indices: 当前的 ``batch`` 是 ``dataset`` 中的哪些数据。仅在 ``DataLoader`` 支持得到当前 ``batch index`` 的时候有值, | |||
其它时候为 None 。 | |||
""" | |||
pass | |||
def on_train_batch_end(self, trainer): | |||
""" | |||
r""" | |||
完成一个 batch 的训练(forward)、梯度回传(backward)、梯度更新(step)、梯度置零、batch_idx_in_epoch与 | |||
global_forward_batches累计加1操作。其中梯度更新】梯度置零操作会考虑 accumulation_steps ,所以不一定在当前 batch 会 | |||
执行。 | |||
:param trainer: | |||
:return: | |||
:param trainer: ``Trainer`` 实例; | |||
""" | |||
pass | |||
def on_exception(self, trainer, exception): | |||
""" | |||
r""" | |||
在训练过程遇到异常时调用。 | |||
:param trainer: | |||
:param exception: 遭遇的异常。 | |||
:return: | |||
:param trainer: ``Trainer`` 实例; | |||
:param exception: 遭遇的异常; | |||
""" | |||
pass | |||
def on_save_model(self, trainer): | |||
""" | |||
r""" | |||
当调用 Trainer.save_model() 时调用,此刻模型还未保存。 | |||
:param trainer: | |||
:return: | |||
:param trainer: ``Trainer`` 实例; | |||
""" | |||
pass | |||
def on_load_model(self, trainer): | |||
""" | |||
r""" | |||
当调用 Trainer.load_model() 加载模型时调用,此刻模型还未加载。 | |||
:param trainer: | |||
:return: | |||
:param trainer: ``Trainer`` 实例; | |||
""" | |||
pass | |||
def on_save_checkpoint(self, trainer) -> Dict: | |||
""" | |||
r""" | |||
当 Trainer 将要保存 checkpoint 的时候触发 (即调用 Trainer.save_checkpoint() 函数时),该函数用于保存当前 callback 在恢复需要的相关数据。 | |||
:param trainer: | |||
:return: | |||
:param trainer: ``Trainer`` 实例; | |||
""" | |||
pass | |||
def on_load_checkpoint(self, trainer, states: Optional[Dict]): | |||
r""" | |||
当 Trainer 要恢复 checkpoint 的时候触发(即调用 Trainer.load_checkpoint() 函数时 Trainer 与 Driver 已经加载好自身的状态), | |||
参数 states 为 on_save_checkpoint() | |||
的返回值。 | |||
参数 states 为 on_save_checkpoint() 的返回值。 | |||
:param trainer: | |||
:param trainer: ``Trainer`` 实例; | |||
:param states: | |||
:return: | |||
""" | |||
pass | |||
def on_before_backward(self, trainer, outputs): | |||
""" | |||
r""" | |||
在 backward 前执行。 | |||
:param trainer: | |||
:param outputs: model 的返回内容。如果有 output_mapping ,则 outputs 中的内容为已经执行了 output_mapping 后的结果。 | |||
:return: | |||
:param trainer: ``Trainer`` 实例; | |||
:param outputs: ``model`` 的返回内容。如果有 ``output_mapping``,则 ``outputs`` 中的内容为已经执行了 ``output_mapping`` 后的结果。 | |||
""" | |||
pass | |||
def on_after_backward(self, trainer): | |||
""" | |||
在 backward 后执行。在多卡场景下,由于 accumulation_steps 的影响,仅在需要真正 update 参数那次梯度回传才会触发梯度同步, | |||
因此在多卡且使用 accumulation_steps 时,可能存在某些 step 各卡上梯度不一致的问题。 | |||
r""" | |||
在 ``backward`` 后执行。在多卡场景下,由于 ``accumulation_steps`` 的影响,仅在需要真正 ``update`` 参数那次梯度回传才会触发梯度同步, | |||
因此在多卡且使用 ``accumulation_steps`` 时,可能存在某些 ``step`` 各卡上梯度不一致的问题。 | |||
:param trainer: | |||
:return: | |||
:param trainer: ``Trainer`` 实例; | |||
""" | |||
pass | |||
def on_before_optimizers_step(self, trainer, optimizers): | |||
""" | |||
r""" | |||
在进行 optimizer 优化进行前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | |||
:param trainer: | |||
:param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 | |||
:return: | |||
:param trainer: ``Trainer`` 实例; | |||
:param optimizers: 优化器,内容为在 ``Trainer`` 初始化时传入的值。 | |||
""" | |||
pass | |||
def on_after_optimizers_step(self, trainer, optimizers): | |||
""" | |||
r""" | |||
在进行 optimizer 优化进行后调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | |||
:param trainer: | |||
:param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 | |||
:return: | |||
:param trainer: ``Trainer`` 实例; | |||
:param optimizers: 优化器,内容为在 ``Trainer`` 初始化时传入的值。 | |||
""" | |||
pass | |||
def on_before_zero_grad(self, trainer, optimizers): | |||
""" | |||
r""" | |||
在进行模型梯度置零前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | |||
:param trainer: | |||
:param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 | |||
:return: | |||
:param trainer: ``Trainer`` 实例; | |||
:param optimizers: 优化器,内容为在 ``Trainer`` 初始化时传入的值。 | |||
""" | |||
pass | |||
def on_after_zero_grad(self, trainer, optimizers): | |||
""" | |||
r""" | |||
在进行模型梯度置零后调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | |||
:param trainer: | |||
:param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 | |||
:return: | |||
:param trainer: ``Trainer`` 实例; | |||
:param optimizers: 优化器,内容为在 ``Trainer`` 初始化时传入的值。 | |||
""" | |||
pass | |||
def on_evaluate_begin(self, trainer): | |||
""" | |||
r""" | |||
在将要进行 evaluate 时调用。如果是设置的以 step 数量 或 自定义地 决定 evaluate 的频率,该接口是在 on_train_batch_end 之后 | |||
进行调用。如果是以 epoch 数量决定调用,该接口是在 on_train_epoch_end 之后调用。 | |||
:param trainer: | |||
:return: | |||
:param trainer: ``Trainer`` 实例; | |||
""" | |||
pass | |||
def on_evaluate_end(self, trainer, results): | |||
""" | |||
r""" | |||
结束 evaluate 时调用,并把 evaluate 的结果传入。 | |||
:param trainer: | |||
:param results: Evaluate 的结果,一般是个 dict 。 | |||
:return: | |||
:param trainer: ``Trainer`` 实例; | |||
:param results: ``Trainer`` 内置的 ``Evaluator`` 评测的结果,通常是个 ``dict``; | |||
""" | |||
pass | |||
@property | |||
def callback_name(self): | |||
""" | |||
callback 的名称,我们会使用该名称从 checkpoint 中读取的相应的 state 并传递给 on_load_checkpoint() 函数。 | |||
r""" | |||
``callback`` 的名称,我们会使用该名称从 ``checkpoint`` 中读取的相应的 ``state`` 并传递给 ``on_load_checkpoint()`` 函数。 | |||
:return: | |||
:return: 返回用于区分该 ``callback`` 实例的 ``name``; | |||
""" | |||
return self.__class__.__name__ | |||
@property | |||
def need_reproducible_sampler(self) -> bool: | |||
""" | |||
r""" | |||
当前 callback 是否需要能够复现的 sampler 。一般用于 checkpoint 类的 callback 。 | |||
:return: | |||
""" | |||
return False | |||
@@ -29,11 +29,10 @@ def _transfer(func): | |||
return wrapper | |||
def prepare_callbacks(callbacks, progress_bar): | |||
def prepare_callbacks(callbacks, progress_bar: str): | |||
""" | |||
:param callbacks: | |||
:param progress_bar: | |||
:param callbacks: 对用户传入的类 ``callback`` 进行检查,查看是否是否继承了我们的 ``Callback`` 类; | |||
:param progress_bar: 选择怎样的 ``progress_bar`` 给 ``Trainer`` 使用; | |||
:return: | |||
""" | |||
_callbacks = [] | |||
@@ -81,7 +80,7 @@ class CallbackManager: | |||
2. 通过 `Trainer` 的参数 `callbacks` 添加的 callback 类; | |||
3. 通过 `Trainer.add_callback_fn` 添加的 callback 函数; | |||
:param callbacks: 初始化时可以传入的一系列 callback 类,通常为用户在初始化 'Trainer' 时直接传入的 callback 类; | |||
:param callbacks: 初始化时可以传入的一系列 callback 类,通常为用户在初始化 ``Trainer`` 时直接传入的 callback 类; | |||
""" | |||
self._need_reproducible_sampler = False | |||
@@ -158,7 +157,6 @@ class CallbackManager: | |||
"filter_states": {"on_train_begin": filter1.state_dict(), ...} | |||
} | |||
} | |||
""" | |||
states = {} | |||
@@ -1,7 +1,7 @@ | |||
import os | |||
import signal | |||
import sys | |||
from typing import Any, Sequence, List, Optional, Callable, Dict, Union, Tuple | |||
from typing import Sequence, List, Optional, Callable, Dict, Union, Tuple | |||
from abc import ABC, abstractmethod | |||
from datetime import datetime | |||
from pathlib import Path | |||
@@ -19,13 +19,11 @@ class Driver(ABC): | |||
r""" | |||
用来初始化 `Driver` 的基类,所有定制的 `driver` 都需要继承此类; | |||
fastNLP 提供的 driver 实例都会同时被 Trainer 和 Evaluator 调用; | |||
:param model: 训练或者评测的模型,需要注意该模型可能为用户已经使用类似 `torch.nn.DataParallel` 或者 | |||
`torch.nn.parallel.DistributedDataParallel` 包裹过的模型; | |||
""" | |||
def __init__(self, model): | |||
r""" | |||
:param model: 训练或者评测的模型,需要注意该模型可能为用户已经使用类似 `torch.nn.DataParallel` 或者 | |||
`torch.nn.parallel.DistributedDataParallel` 包裹过的模型; | |||
""" | |||
self.model = model | |||
# 这些属性用于 open_subprocess 和 on_exception 函数协同配合; | |||
@@ -36,24 +34,25 @@ class Driver(ABC): | |||
def setup(self): | |||
r""" | |||
该函数用来初始化训练环境,例如将模型迁移到对应的设备上等; | |||
多卡的 driver 的该函数要更为复杂一些,例如其可能需要开启多进程之间的通信环境,以及设置一些环境变量和其余所需要的变量值; | |||
多卡的 ``driver`` 的该函数要更为复杂一些,例如其可能需要开启多进程之间的通信环境,以及设置一些环境变量和其余所需要的变量值; | |||
""" | |||
def set_dist_repro_dataloader(self, dataloader, dist=None, reproducible: bool = False): | |||
r""" | |||
根据输入的 dataloader 得到一个 支持分布式 (distributed) 与 可复现的 (reproducible) 的 dataloader。 | |||
:param dataloader: 根据 dataloader 设置其对应的分布式版本以及可复现版本 | |||
:param dist: 应当为一个字符串,其值应当为以下之一:[None, "dist", "unrepeatdist"];为 None 时,表示不需要考虑当前 dataloader | |||
切换为分布式状态;为 'dist' 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在 | |||
不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | |||
数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; | |||
否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; | |||
根据输入的 ``dataloader`` 得到一个 支持分布式 (``distributed``) 与 可复现的 (``reproducible``) 的 dataloader。 | |||
:param dataloader: 根据 ``dataloade``r 设置其对应的分布式版本以及可复现版本; | |||
:param dist: 应当为一个字符串,其值应当为以下之一:``[None, "dist", "unrepeatdist"]``;为 ``None`` 时,表示不需要考虑当前 dataloader | |||
切换为分布式状态;为 ``dist`` 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在 | |||
不同 gpu 上出现重复;为 ``unrepeatdist`` 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | |||
数据,允许不同 gpu 上 batch 的数量不一致。 | |||
其中 trainer 中 kwargs 的参数 ``use_dist_sampler`` 为 ``True`` 时,该值为 ``dist``; | |||
否则为 ``None``,evaluator 中的 kwargs 的参数 ``use_dist_sampler`` 为 ``True`` 时,该值为 ``unrepeatdist``,否则为 ``None``; | |||
注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load_checkpoint 函数在调用; | |||
当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; | |||
:param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 | |||
可以可以加载。 | |||
:param reproducible: 如果为 ``False``,不要做任何考虑;如果为 ``True``,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 | |||
该状态可以加载到一个全新的 dataloader 中然后恢复其状态; | |||
:return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, | |||
如果传入的 dataloader 中是 ReproducibleSampler 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的 | |||
dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。 | |||
@@ -65,50 +64,50 @@ class Driver(ABC): | |||
def set_deterministic_dataloader(self, dataloader): | |||
r""" | |||
为了确定性训练要对 dataloader 进行修改,保证在确定随机数种子后,每次重新训练得到的结果是一样的;例如对于 torch 的 dataloader,其 | |||
需要将 worker_init_fn 替换; | |||
为了确定性训练要对 ``dataloader`` 进行修改,保证在确定随机数种子后,每次重新训练得到的结果是一样的;例如对于 ``pytorch`` 的 ``dataloader``,其 | |||
需要将 ``worker_init_fn`` 替换; | |||
""" | |||
def set_sampler_epoch(self, dataloader, cur_epoch_idx): | |||
r""" | |||
对于分布式的 sampler,例如 torch 的 DistributedSampler,其需要在每一个 epoch 前设置随机数种子,来保证每一个进程上的 shuffle 是一样的; | |||
dataloader 中可能真正发挥作用的是 batch_sampler 也可能是 sampler。 | |||
对于分布式的 ``sampler``,例如 ``pytorch`` 的 ``DistributedSampler``,其需要在每一个 ``epoch`` 前设置随机数种子,来保证每一个进程上的 ``shuffle`` 是一样的; | |||
``dataloader`` 中可能真正发挥作用的是 ``batch_sampler`` 也可能是 ``sampler``。 | |||
:param dataloader: 需要设置 epoch 的 dataloader 。 | |||
:param cur_epoch_idx: 当前是第几个 epoch; | |||
:param dataloader: 需要设置 ``epoch`` 的 ``dataloader``; | |||
:param cur_epoch_idx: 当前是第几个 ``epoch``; | |||
""" | |||
@abstractmethod | |||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | |||
""" | |||
通过调用 `fn` 来实现训练时的前向传播过程; | |||
注意 Trainer 和 Evaluator 会调用该函数来实现网络的前向传播过程,其中传入该函数的参数 `fn` 是函数 `get_model_call_fn` 所返回的 | |||
r""" | |||
通过调用 ``fn`` 来实现训练时的前向传播过程; | |||
注意 ``Trainer`` 和 ``Evaluator`` 会调用该函数来实现网络的前向传播过程,其中传入该函数的参数 ``fn`` 是函数 ``get_model_call_fn`` 所返回的 | |||
函数; | |||
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | |||
:param fn: 调用该函数进行一次计算。 | |||
:param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call 函 | |||
数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; | |||
:return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | |||
:param signature_fn: 由 ``Trainer`` 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 ``Dict`` 的时候,我们会自动调用 ``auto_param_call`` 函 | |||
数,而一些被包裹的模型需要暴露其真正的函数签名,例如 ``DistributedDataParallel`` 的调用函数是 ``forward``,但是需要其函数签名为 ``model.module.forward``; | |||
:return: 返回由 ``fn`` 返回的结果(应当为一个 ``dict`` 或者 ``dataclass``,但是不需要我们去检查); | |||
""" | |||
raise NotImplementedError("Each specific driver should implemented its own `model_call` function.") | |||
@abstractmethod | |||
def get_model_call_fn(self, fn: str) -> Tuple: | |||
""" | |||
该函数会接受 Trainer 的 train_fn 或者 Evaluator 的 evaluate_fn,返回一个实际用于调用 driver.model_call 时传入的函数参数; | |||
该函数会在 Trainer 和 Evaluator 在 driver.setup 函数之后调用; | |||
r""" | |||
该函数会接受 ``Trainer`` 的 ``train_fn`` 或者 ``Evaluator`` 的 ``evaluate_fn``,返回一个实际用于调用 ``driver.model_call`` 时传入的函数参数; | |||
该函数会在 ``Trainer`` 和 ``Evaluator`` 在 ``driver.setup`` 函数之后调用; | |||
之所以设置该函数的目的在于希望将具体的 model_call function 从 driver 中抽离出来,然后将其附着在 Trainer 或者 Evaluator 身上; | |||
这样是因为在新版的设计中,使用 model 的哪种方法来进行 `train step` 或者 `evaluate step` 是通过额外的参数 `train_fn` 和 | |||
`evaluate_fn` 来确定的,而二者又分别是通过 Trainer 和 Evaluator 来控制的;因此不能将确定具体的 `train step fn` 和 | |||
`evaluate step fn` 的逻辑放在每一个 driver 的初始化的时候(因此在 Trainer 初始化第一个 driver 时,Evaluator 还没有初始化,但是 | |||
`evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中; | |||
这样是因为在新版的设计中,使用 model 的哪种方法来进行 ``train step`` 或者 ``evaluate step`` 是通过额外的参数 ``train_fn`` 和 | |||
``evaluate_fn`` 来确定的,而二者又分别是通过 Trainer 和 Evaluator 来控制的;因此不能将确定具体的 ``train step fn`` 和 | |||
``evaluate step fn`` 的逻辑放在每一个 driver 的初始化的时候(因此在 Trainer 初始化第一个 driver 时,Evaluator 还没有初始化,但是 | |||
``evaluate step fn`` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中; | |||
这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数,具体逻辑如下所示: | |||
1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward` | |||
这一函数应当通过参数 ``fn`` 来判断应当返回的实际的调用的函数,具体逻辑如下所示: | |||
1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 ``fn``,则默认调用模型的 ``forward`` | |||
函数,然后给出 warning; | |||
2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错; | |||
2. 如果 fn 是其他字符串,那么如果模型没有定义方法 ``fn`` 则直接报错; | |||
注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的 | |||
forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此 | |||
@@ -121,6 +120,9 @@ class Driver(ABC): | |||
@property | |||
def model(self): | |||
r""" | |||
:return: 返回 driver 中在实际训练或者评测时所使用的模型; | |||
""" | |||
return self._model | |||
@model.setter | |||
@@ -147,6 +149,9 @@ class Driver(ABC): | |||
@property | |||
def model_device(self): | |||
r""" | |||
:return: 返回 driver 中模型实际所在的设备; | |||
""" | |||
return self._model_device | |||
@model_device.setter | |||
@@ -155,28 +160,30 @@ class Driver(ABC): | |||
@property | |||
def data_device(self): | |||
""" | |||
:return: 返回 driver 中数据默认会被迁移到的设备; | |||
""" | |||
return self.model_device | |||
@staticmethod | |||
def _check_optimizer_legality(optimizers): | |||
""" | |||
r""" | |||
对于用户传入 trainer 的每一个 optimizer,检测其是否合理,因为不同的深度学习框架所使用的的 optimizer 是不相同的; | |||
:param optimizers: 需要检测的 `optimizers`; | |||
""" | |||
raise NotImplementedError("Each specific driver should implemented its own `_check_optimizer_legality` function.") | |||
raise NotImplementedError( | |||
"Each specific driver should implemented its own `_check_optimizer_legality` function.") | |||
def set_optimizers(self, optimizers=None): | |||
""" | |||
r""" | |||
trainer 会调用该函数将用户传入的 optimizers 挂载到 driver 实例上; | |||
:param optimizers: | |||
:return: | |||
""" | |||
self.optimizers = optimizers | |||
@abstractmethod | |||
def backward(self, loss): | |||
""" | |||
r""" | |||
实现深度学习中的反向传播过程; | |||
:param loss: 用来实现反向传播的损失函数值; | |||
@@ -219,7 +226,7 @@ class Driver(ABC): | |||
@property | |||
def auto_cast(self): | |||
""" | |||
r""" | |||
fp16 的上下文环境; | |||
:return: 返回一个用于 fp16 计算的上下文环境; | |||
@@ -246,7 +253,7 @@ class Driver(ABC): | |||
r""" | |||
加载模型的函数;将 filepath 中的模型加载并赋值给当前 model 。 | |||
:param filepath: 需要被加载的对象的文件位置(需要包括文件名)或一个 BytesIO 对象; | |||
:param filepath: 需要被加载的对象的文件位置(需要包括文件名)或一个 ``BytesIO`` 对象; | |||
:param load_state_dict: 保存的文件是否只是模型的权重,还是完整的模型。即便是保存的完整的模型,此处也只能使用尝试加载filepath | |||
模型中的权重到自身模型,而不会直接替代当前 Driver 中的模型。 | |||
:return: 返回加载指定文件后的结果; | |||
@@ -254,7 +261,8 @@ class Driver(ABC): | |||
raise NotImplementedError("Each specific driver should implemented its own `load_model` function.") | |||
@abstractmethod | |||
def save_checkpoint(self, folder, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||
def save_checkpoint(self, folder, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, | |||
**kwargs): | |||
r""" | |||
断点重训的保存函数,该函数会负责保存模型和 optimizers, fp16 的 state_dict;以及模型的保存(若 should_save_model 为 True) | |||
@@ -271,7 +279,8 @@ class Driver(ABC): | |||
raise NotImplementedError("Each specific driver should implemented its own `save_checkpoint` function.") | |||
@abstractmethod | |||
def load_checkpoint(self, folder: Union[str, Path], dataloader, only_state_dict: bool =True, should_load_model: bool = True, **kwargs) -> Dict: | |||
def load_checkpoint(self, folder: Union[str, Path], dataloader, only_state_dict: bool = True, should_load_model: bool = True, | |||
**kwargs) -> Dict: | |||
r""" | |||
断点重训的加载函数,注意该函数会负责读取数据,并且恢复 optimizers , fp16 的 state_dict 和 模型(根据 should_load_model )和; | |||
其它在 Driver.save_checkpoint() 函数中执行的保存操作,然后将一个 state 字典返回给 trainer ( 内容为Driver.save_checkpoint() 接受到的 states )。 | |||
@@ -287,28 +296,30 @@ class Driver(ABC): | |||
:param should_load_model: 是否应该加载模型,如果为False,Driver 将不负责加载模型。若该参数为 True ,但在保存的状态中没有 | |||
找到对应的模型状态,则报错。 | |||
:return: 需要返回 save_checkpoint 函数输入的 states 内容 | |||
'dataloader',返回的是根据传入的 dataloader 与 保存的状态一起设置为合理的状态,可以返回的对象与传入的dataloader是同一个。 | |||
在保存与当前传入 data sample 数目不一致时报错。 | |||
'batch_idx_in_epoch': int 类型的数据,表明当前 epoch 进行到了进行到了第几个 batch 了。 请注意,该值不能是只能通过保存的 | |||
数据中读取的,因为前后两次运行 batch_size 可能由变化。该数字的原则应该符合以下等式 | |||
'返回 dataloader 还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数' 。 | |||
由于 '返回 dataloader 还会产生的batch数量' 这个数量在 batch_size 与 drop_last 参数给定的情况下,无法改变,因此 | |||
只能通过调整 batch_idx_in_epoch 这个值来使等式成立。一个简单的计算原则如下 | |||
当drop_last为True,等同于 floor(sample_in_this_rank/batch_size) - floor(num_left_samples/batch_size); | |||
当drop_last为False,等同于 ceil(sample_in_this_rank/batch_size) - ceil(num_left_samples/batch_size)。 | |||
* *dataloader* -- 返回的是根据传入的 dataloader 与 保存的状态一起设置为合理的状态,可以返回的对象与传入的dataloader是同一个。 | |||
在保存与当前传入 data sample 数目不一致时报错。 | |||
* *batch_idx_in_epoch* -- int 类型的数据,表明当前 epoch 进行到了进行到了第几个 batch 了。 请注意,该值不能是只能通过保存的 | |||
数据中读取的,因为前后两次运行 batch_size 可能由变化。该数字的原则应该符合以下等式 | |||
'返回 dataloader 还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数' 。 | |||
由于 '返回 dataloader 还会产生的batch数量' 这个数量在 batch_size 与 drop_last 参数给定的情况下,无法改变,因此 | |||
只能通过调整 batch_idx_in_epoch 这个值来使等式成立。一个简单的计算原则如下 | |||
当drop_last为True,等同于 floor(sample_in_this_rank/batch_size) - floor(num_left_samples/batch_size); | |||
当drop_last为False,等同于 ceil(sample_in_this_rank/batch_size) - ceil(num_left_samples/batch_size)。 | |||
""" | |||
raise NotImplementedError("Each specific driver should implemented its own `load_checkpoint` function.") | |||
@staticmethod | |||
def tensor_to_numeric(tensor, reduce: Optional[str]=None): | |||
def tensor_to_numeric(tensor, reduce: Optional[str] = None): | |||
r""" | |||
将一个 `tensor` 对象(仅处理当前 driver 使用的 tensor 即可)转换为 python 的 `numeric` 对象;如果 tensor 只包含一个 | |||
元素则返回 float 或 int 。 | |||
将一个 ``tensor`` 对象(仅处理当前 driver 使用的 tensor 即可)转换为 python 的 ``numeric`` 对象;如果 ``tensor`` 只包含一个 | |||
元素则返回 ``float`` 或 ``int``; | |||
:param tensor: 需要被转换的 `tensor` 对象 | |||
:param reduce: 可选 ['sum', 'max', 'mea', 'min'],如果不为 None 将使用该 reduce 方法来处理当前 tensor 再返回 | |||
float 或 int 对象。 | |||
:return: 转换后返回的结果 | |||
:param tensor: 需要被转换的 `tensor` 对象; | |||
:param reduce: 可选 ``['sum', 'max', 'mea', 'min']``,如果不为 ``None`` 将使用该 ``reduce`` 方法来处理当前 ``tensor`` 再返回 | |||
``float`` 或 ``int`` 对象; | |||
:return: 转换后返回的结果; | |||
""" | |||
raise NotImplementedError("Each specific driver should implemented its own `tensor_to_numeric` function.") | |||
@@ -321,7 +332,7 @@ class Driver(ABC): | |||
""" | |||
def unwrap_model(self): | |||
""" | |||
r""" | |||
保证用户拿到的模型一定是最原始的模型; | |||
注意因为我们把保存模型的主要逻辑和代码移到了 `Driver` 中,因此在 `save_model` 函数中,一定要先调用此函数来保证我们保存的模型一定是 | |||
最为原始的模型; | |||
@@ -342,14 +353,14 @@ class Driver(ABC): | |||
@abstractmethod | |||
def move_data_to_device(self, batch): | |||
r""" | |||
将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。 | |||
将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构; | |||
:return: 将移动到指定机器上的 batch 对象返回; | |||
""" | |||
def get_local_rank(self) -> int: | |||
r""" | |||
返回当前的local_rank,本函数的返回值只在运行分布式训练的时候有实际含义。 | |||
返回当前的local_rank,本函数的返回值只在运行分布式训练的时候有实际含义; | |||
:return: 一个整数值,表示当前进程在当前这台机器上的序号; | |||
""" | |||
@@ -358,13 +369,13 @@ class Driver(ABC): | |||
def barrier(self): | |||
r""" | |||
用于在多进程工作时同步各进程的工作进度,运行快的进程运行到这里会等待运行慢的进程,只有所有进程都运行到此函数时,所有的进程才会继续运行; | |||
仅在多分布式训练场景中有使用。 | |||
仅在多分布式训练场景中有使用; | |||
注意,该函数的行为会受到 FASTNLP_NO_SYNC 的影响。仅当 FASTNLP_NO_SYNC 在 os.environ 中不存在,或小于 1 时才真的执行 barrier 。 | |||
注意,该函数的行为会受到 FASTNLP_NO_SYNC 的影响。仅当 FASTNLP_NO_SYNC 在 os.environ 中不存在,或小于 1 时才真的执行 barrier; | |||
""" | |||
def is_distributed(self) -> bool: | |||
""" | |||
r""" | |||
当前的 driver 实例是否是分布式的; | |||
:return: 返回一个 bool 值,如果当前的 driver 实例是用于分布式的,那么返回 True; | |||
@@ -372,7 +383,7 @@ class Driver(ABC): | |||
return False | |||
def on_exception(self): | |||
""" | |||
r""" | |||
该函数用于在训练或者预测过程中出现错误时正确地关掉其它的进程,这一点是通过在多进程 driver 调用 open_subprocess 的时候将每一个进程 | |||
的 pid 记录下来,然后在出现错误后,由出现错误的进程手动地将其它进程 kill 掉; | |||
@@ -390,40 +401,38 @@ class Driver(ABC): | |||
'exc_local_rank': self.get_local_rank(), | |||
} | |||
sys.stderr.write("\nException info:\n") | |||
sys.stderr.write(json.dumps(_write_exc_info, indent=2)+"\n") | |||
sys.stderr.write(json.dumps(_write_exc_info, indent=2) + "\n") | |||
sys.stderr.write(f"Start to stop these pids:{self._pids}, please wait several seconds.\n") | |||
for pid in self._pids: | |||
if pid != os.getpid(): | |||
os.kill(pid, signal.SIGKILL) | |||
def broadcast_object(self, obj, src:int=0, group=None, **kwargs): | |||
""" | |||
从 src 端将 obj 对象(可能是 tensor ,可能是 object )broadcast 到其它所有进程。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行 | |||
传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。 | |||
def broadcast_object(self, obj, src: int = 0, group=None, **kwargs): | |||
r""" | |||
从 ``src`` 端将 ``obj`` 对象(可能是 ``tensor``,可能是 ``object`` )broadcast 到其它所有进程。如果是非 ``tensor`` 的对象会尝试使用 ``pickle`` 进行打包进行 | |||
传输,然后再 ``dst`` 处再加载回来。仅在分布式的 ``driver`` 中有实际意义。 | |||
:param obj: obj,可能是 Tensor 或 嵌套类型的数据 | |||
:param int src: source 的 global rank 。 | |||
:param group: 所属的 group | |||
:param kwargs: | |||
:return: 输入的 obj 。 | |||
:param obj: obj,可能是 ``Tensor`` 或 嵌套类型的数据; | |||
:param src: source 的 ``global rank``; | |||
:param group: 所属的通信组; | |||
:return: 输入的 ``obj``; | |||
""" | |||
if not self.is_distributed(): | |||
return obj | |||
raise NotImplementedError(f"Driver:{self.__class__.__name__} does not support `broadcast_object` method right " | |||
f"now.") | |||
def all_gather(self, obj, group)->List: | |||
""" | |||
def all_gather(self, obj, group) -> List: | |||
r""" | |||
将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过 | |||
pickle 进行序列化,接收到之后再反序列化。 | |||
:param obj: 可以是 float/int/bool/np.ndarray/{}/[]/Tensor等。 | |||
:param group: | |||
:return: 返回值应该是 [obj0, obj1, ...], 其中obj1是rank0上的对象,obj1是rank1上的对象... | |||
:param obj: 可以是 ``float/int/bool/np.ndarray/{}/[]/Tensor`` 等; | |||
:param group: 用于不同进程之间互相通信的通信组; | |||
:return: 返回值应该是 ``[obj0, obj1, ...]``,其中 ``obj1`` 是 ``rank0`` 上的对象,``obj1`` 是 ``rank1`` 上的对象; | |||
""" | |||
if not self.is_distributed(): | |||
return [obj] | |||
raise NotImplementedError(f"Driver:{self.__class__.__name__} does not support `all_gather` method right " | |||
f"now.") | |||
@@ -1,3 +1,130 @@ | |||
r""" | |||
""" | |||
r""" | |||
`TorchDDPDriver` 目前支持的三种启动方式: | |||
1. 用户自己不进行 ddp 的任何操作,直接使用我们的 Trainer,这时是由我们自己使用 `open_subprocesses` 拉起多个进程, | |||
然后 `TorchDDPDriver` 自己通过调用 `dist.init_process_group` 来初始化 ddp 的通信组;(情况 A) | |||
2. 用户同样不在 Trainer 之外初始化 ddp,但是用户自己使用 python -m torch.distributed.launch 拉起来创建多个进程,这时我们仍旧 | |||
会通过调用 `dist.init_process_group` 来初始化 ddp 的通信组;(情况 B) | |||
3. 用户自己在外面初始化 DDP,并且通过 python -m torch.distributed.launch 拉起,这时无论是多个进程的拉起和 ddp 的通信组的建立 | |||
都由用户自己操作,我们只会在 driver.setup 的时候对 `TorchDDPDriver` 设置一些必要的属性值;(情况 C) | |||
注意多机的启动强制要求用户在每一台机器上使用 python -m torch.distributed.launch 启动;因此我们不会在 `TorchDDPDriver` 中保存 | |||
任何当前有多少台机器的信息(num_nodes,不是 gpu 的数量); | |||
Part 1:三种启动方式的具体分析: | |||
(1)对于用户运行的脚本中,如果 `driver.setup` 只会被调用一次(意味着用户的启动脚本中只初始化了一个 trainer/evaluator)时, | |||
`TorchDDPDriver` 在初始化以及 `setup` 函数中会做的事情分别如下所示: | |||
-> 情况 A:这种情况下用户传入的 model 在一定是普通的 model(没有经 `DistributedDataParallel` 包裹的model), | |||
因为 `DistributedDataParallel` 的使用一定要求 init_process_group 已经被调用用来建立当前的 ddp 通信组;但是这意味着如果 | |||
用户需要使用 2 张以上的显卡,那么其必然需要使用 torch.distributed.launch 来启动,意味着就不是情况 A 了; | |||
这时我们首先会调用 `TorchDDPDriver.open_subprocess` 函数来拉起多个进程,其中进程的数量等于用户传入给 trainer 的使用的 gpu | |||
的数量(例如 `Trainer` 中的参数是 device=[0, 1, 6, 7],那么我们就会使用第 0、1、6、7 张 gpu 来拉起 4 个进程); | |||
接着我们会调用 `dist.init_process_group` 来初始化各个进程之间的通信组; | |||
这里需要注意拉起的新的进程会从前到后完整地运行一遍用户的启动脚本(例如 main.py),因此也都会运行这两个函数,但是需要注意只有进程 0 | |||
才会去真正地运行 `TorchDDPDriver.open_subprocess`;进程 0 运行到 `dist.init_process_group`,pytorch 会阻塞进程 0 继续 | |||
向前运行,直到其它进程也运行到这里; | |||
最后我们会设置这个进程对应的 device,然后将模型迁移到对应的机器上,再使用 `DistributedDataParallel` 将模型包裹; | |||
至此,ddp 的环境配置过程全部完成; | |||
-> 情况 B:注意这种情况我们直接限定了用户是通过 torch.distributed.launch 拉起,并且没有自己建立 ddp 的通信组。这时在 | |||
`TorchDDPDriver` 的初始化和 setup 函数的调用过程中,与情况 A 首要的不同就在于用户在 trainer 中输入的参数 device 不再有效, | |||
这时每个进程所使用的 gpu 是我们直接通过 `torch.device("cuda:{local_rank}")` 来配置的;因此,如果用户想要实现使用特定 gpu | |||
设备的目的,可以通过自己设置环境变量实现(例如 os.environ["CUDA_VISIBLE_DEVICE"] 来实现);剩下的操作和情况 A 类似; | |||
-> 情况 C:注意这种情况我们限定了用户是通过 torch.distributed.launch 拉起,并且 ddp 的通信组也是由自己建立。这时基本上所有的 | |||
与操作相关的操作都应当由用户自己完成,包括迁移模型到对应 gpu 上以及将模型用 `DistributedDataParallel` 包裹等。 | |||
(2)如果 `driver.setup` 函数在脚本中会被调用两次及以上(意味着用户的启动脚本初始化了两个及以上的 trainer/evaluator)时: | |||
注意这种情况下我们是会保证前后两个 trainer/evaluator 使用的 `TorchDDPDriver` 以及其初始化方式的一致性,换句话说,如果 trainer1 | |||
检测到的启动方式是 '情况 A',那么我们会保证 trainer2 检测到的启动方式同样是 '情况A'(即使这需要一些额外的处理);因此这里我们主要讨论 | |||
我们是通过怎样的操作来保证 trainer2/3/... 检测到的启动方式是和 trainer1 一致的;简单来说,我们是通过使用环境变量来标记每一种不同的 | |||
启动方式来实现这一点的: | |||
我们会使用 `FASTNLP_DISTRIBUTED_CHECK` 来标记 '情况 A',使用 `fastnlp_torch_launch_not_ddp` 来标记 '情况 B',意味着我们在 | |||
使用 '情况 A' 来启动 `TorchDDPDriver` 时,我们会将 `FASTNLP_DISTRIBUTED_CHECK` 这一字符串注入到环境变量中,而 '情况 B' 时则 | |||
会将 `fastnlp_torch_launch_not_ddp` 这一字符串注入到环境变量中。因此在 trainer2 的 `TorchDDPDriver` 的初始化和 setup 过程中, | |||
如果检测到这些特殊的环境变量,我们就会将启动方式变更为其对应的启动方式,即使其它的参数特征属于另外的启动方式。 | |||
Part 2:对应的代码细节: | |||
1. 如何判断当前的各进程之间的通信组已经被建立(ddp 已经被初始化); | |||
dist.is_initialized(); | |||
2. 如何判断不同的进程是否是由 `python -m torch.distributed.launch` 拉起还是由我们的 `TorchDDPDriver.open_subprocess` | |||
函数拉起; | |||
我们会在用户脚本 `import fastNLP` 的时候检测当前的环境变量中是否有 'LOCAL_RANK'、'WORLD_SIZE' 以及没有 `FASTNLP_DISTRIBUTED_CHECK`, | |||
如果满足条件,则我们会向环境变量中注入特殊的值 'FASTNLP_BACKEND_LAUNCH' 来标记用户是否使用了 `python -m torch.distributed.launch` | |||
来拉起多个进程; | |||
3. 整体的处理判断流程: | |||
___________________________________ | |||
|进入 TorchDDPDriver 的 __init__ 函数| | |||
——————————————————————————————————— | |||
↓ | |||
___________________________________________________ | |||
| 判断不同的进程是否是由 torch.distributed.launch 拉起 | | |||
|(或者我们自己的 open_subprocess 函数拉起) | --------------> | |||
——————————————————————————————————————————————————— | | |||
↓ 是由 torch.distributed.launch 拉起 | 我们自己的 open_subprocess 函数拉起多个进程 | |||
___________________________ | | |||
←←←←← | 检测用户是否自己初始化了 ddp | | | |||
↓ ——————————————————————————— ↓ | |||
↓ ↓ 是 ________ | |||
↓ ______ | 情况 A | | |||
↓ 否 |情况 C| ————————— | |||
↓ ——————— | |||
↓ | |||
↓ ______ | |||
↓ -----------> |情况 B| | |||
——————— | |||
4. 为了完成全部的建立 ddp 所需要的操作,三种情况都需要做的事情,以及每件事情的职责归属: | |||
情况 A | 情况 B | 情况 C | |||
________________________________________________________________________________________________________ | |||
配置 ddp 所 | TorchDDPDriver.open_subprocess | torch.distributed.launch| torch.distributed.launch | |||
需要的环境变量 | | | | |||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||
开启多个进程 | TorchDDPDriver.open_subprocess | torch.distributed.launch| torch.distributed.launch | |||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||
调用 dist. | | | | |||
init_process\ | TorchDDPDriver.setup | TorchDDPDriver.setup | 用户自己调用 | |||
_group 函数 | | | | |||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||
设置 TorchDDPDriver | | | | |||
的 world_size 和 | TorchDDPDriver.setup | TorchDDPDriver.setup | TorchDDPDriver.setup | |||
global_rank 属性 | | | | |||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||
Part 3:其它的处理细节: | |||
1. 环境变量; | |||
fastNLP 的 `TorchDDPDriver` 运行时所需要的环境变量分为两种,一种是 torch 的 ddp 运行所需要的环境变量;另一种是 fastNLP 自己 | |||
的环境变量。前者的配置情况如上表所示;而后者中的大多数环境变量则是在用户 import fastNLP 时就设置好了; | |||
2. parallel_device, model_device 和 data_device 的关系; | |||
parallel_device 为 `TorchDDPDriver` 的参数,model_device 和 data_device 都为 driver 的属性; | |||
其中 data_device 仅当情况 C 时由用户自己指定;如果其不为 None,那么在模型 forward 的时候,我们就会将数据迁移到 data_device 上; | |||
model_device 永远都为单独的一个 torch.device; | |||
情况 A | 情况 B | 情况 C | |||
________________________________________________________________________________________________________ | |||
parallel_device | 由用户传入trainer的参数 | 为 torch.device( | 为 torch.device( | |||
| device 决定,必须是一个list, | "cuda:{local_rank}") | "cuda:{local_rank}") | |||
| 其中每一个对象都是 torch.device | | | |||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||
model_device | parallel_device[local_rank] | parallel_device | None | |||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||
data_device | model_device | model_device | 由用户传入 trainer 的参数 | |||
| | | data_device 决定 | |||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||
3. _DDPWrappingModel 的作用; | |||
因为我们即需要调用模型的 `train_step`、`evaluate_step`、`test_step` 方法,又需要通过 `DistributedDataParallel` 的 | |||
forward 函数来帮助我们同步各个设备上的梯度,因此我们需要先将模型单独包裹一层,然后在 forward 的时候,其先经过 `DistributedDataParallel` | |||
的 forward 方法,然后再经过 `_DDPWrappingModel` 的 forward 方法,我们会在该 forward 函数中进行判断,确定调用的是模型自己的 | |||
forward 函数,还是 `train_step`、`evaluate_step`、`test_step` 方法。 | |||
4. 当某一个进程出现 exception 后,`TorchDDPDriver` 的处理; | |||
不管是什么情况,`TorchDDPDriver` 在 `setup` 函数的最后,都会将所有进程的 pid 主动记录下来,这样当一个进程出现 exception 后, | |||
driver 的 on_exception 函数就会被 trainer 调用,其会调用 os.kill 指令将其它进程 kill 掉; | |||
""" | |||
import os | |||
import sys | |||
import __main__ | |||
@@ -7,6 +134,7 @@ from time import sleep | |||
from typing import List, Optional, Union, Dict, Tuple, Callable | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
import torch.distributed as dist | |||
@@ -26,7 +154,8 @@ from fastNLP.core.drivers.torch_driver.utils import ( | |||
) | |||
from fastNLP.core.drivers.utils import distributed_open_proc | |||
from fastNLP.core.utils import auto_param_call, check_user_specific_params | |||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, ReproducibleBatchSampler, \ | |||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, \ | |||
ReproducibleBatchSampler, \ | |||
re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler | |||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC | |||
from fastNLP.core.log import logger | |||
@@ -34,6 +163,81 @@ from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gathe | |||
class TorchDDPDriver(TorchDriver): | |||
r""" | |||
``TorchDDPDriver`` 通过开启多个进程,让每个进程单独使用一个 gpu 设备来实现分布式训练; | |||
.. note:: | |||
您在绝大多数情况下不需要自己使用到该类,通过向 ``Trainer`` 传入正确的参数,您可以方便快速地部署您的分布式训练; | |||
``TorchDDPDriver`` 目前支持的三种启动方式: | |||
1. 用户自己不进行 ``ddp`` 的任何操作,直接使用我们的 ``Trainer``,这时是由我们自己使用 ``open_subprocesses`` 拉起多个进程, | |||
然后 ``TorchDDPDriver`` 自己通过调用 ``dist.init_process_group`` 来初始化 ddp 的通信组;(情况 A) | |||
.. code-block:: | |||
trainer = Trainer( | |||
... | |||
driver='torch', | |||
device=[0, 1] | |||
) | |||
trainer.run() | |||
通过运行 ``python train.py`` 启动; | |||
2. 用户同样不在 ``Trainer`` 之外初始化 ``ddp``,但是用户自己使用 ``python -m torch.distributed.launch`` 拉起来创建多个进程,这时我们仍旧 | |||
会通过调用 ``dist.init_process_group`` 来初始化 ``ddp`` 的通信组;(情况 B) | |||
.. code-block:: | |||
trainer = Trainer( | |||
... | |||
driver='torch', | |||
device=None | |||
) | |||
trainer.run() | |||
通过运行 ``python -m torch.distributed.launch --nproc_per_node 2 train.py`` 启动; | |||
3. 用户自己在外面初始化 ``DDP``,并且通过 ``python -m torch.distributed.launch`` 拉起,这时无论是多个进程的拉起和 ddp 的通信组的建立 | |||
都由用户自己操作,我们只会在 ``driver.setup`` 的时候对 ``TorchDDPDriver`` 设置一些必要的属性值;(情况 C) | |||
.. code-block:: | |||
import torch.distributed as dist | |||
from torch.nn.parallel import DistributedDataParallel | |||
# 获取当前的进程信息; | |||
... | |||
# 初始化 ddp 不同进程间的通信组; | |||
dist.init_process_group(...) | |||
# 初始化模型使用 DistributedDataParallel 包裹; | |||
model = Model() | |||
model = DistributedDataParallel(model, ...) | |||
# 注意此时仍旧不需要您主动地将 datalaoder 的 sampler 替换为 DistributedSampler; | |||
trainer = Trainer( | |||
... | |||
driver='torch', | |||
device=None | |||
) | |||
trainer.run() | |||
通过运行 ``python -m torch.distributed.launch --nproc_per_node 2 train.py`` 启动; | |||
注意多机的启动强制要求用户在每一台机器上使用 ``python -m torch.distributed.launch`` 启动;因此我们不会在 ``TorchDDPDriver`` 中保存 | |||
任何当前有多少台机器的信息; | |||
:param model: 传入给 ``Trainer`` 的 ``model`` 参数; | |||
:param parallel_device: 用于分布式训练的 ``gpu`` 设备; | |||
:param is_pull_by_torch_run: 标志当前的脚本的启动是否由 ``python -m torch.distributed.launch`` 启动的; | |||
:param fp16: 是否开启 fp16 训练; | |||
:param kwargs: 其余的一些用于设定 ddp 训练的参数; | |||
""" | |||
def __init__( | |||
self, | |||
model, | |||
@@ -42,129 +246,7 @@ class TorchDDPDriver(TorchDriver): | |||
fp16: bool = False, | |||
**kwargs | |||
): | |||
r""" | |||
`TorchDDPDriver` 目前支持的三种启动方式: | |||
1. 用户自己不进行 ddp 的任何操作,直接使用我们的 Trainer,这时是由我们自己使用 `open_subprocesses` 拉起多个进程, | |||
然后 `TorchDDPDriver` 自己通过调用 `dist.init_process_group` 来初始化 ddp 的通信组;(情况 A) | |||
2. 用户同样不在 Trainer 之外初始化 ddp,但是用户自己使用 python -m torch.distributed.launch 拉起来创建多个进程,这时我们仍旧 | |||
会通过调用 `dist.init_process_group` 来初始化 ddp 的通信组;(情况 B) | |||
3. 用户自己在外面初始化 DDP,并且通过 python -m torch.distributed.launch 拉起,这时无论是多个进程的拉起和 ddp 的通信组的建立 | |||
都由用户自己操作,我们只会在 driver.setup 的时候对 `TorchDDPDriver` 设置一些必要的属性值;(情况 C) | |||
注意多机的启动强制要求用户在每一台机器上使用 python -m torch.distributed.launch 启动;因此我们不会在 `TorchDDPDriver` 中保存 | |||
任何当前有多少台机器的信息(num_nodes,不是 gpu 的数量); | |||
Part 1:三种启动方式的具体分析: | |||
(1)对于用户运行的脚本中,如果 `driver.setup` 只会被调用一次(意味着用户的启动脚本中只初始化了一个 trainer/evaluator)时, | |||
`TorchDDPDriver` 在初始化以及 `setup` 函数中会做的事情分别如下所示: | |||
-> 情况 A:这种情况下用户传入的 model 在一定是普通的 model(没有经 `DistributedDataParallel` 包裹的model), | |||
因为 `DistributedDataParallel` 的使用一定要求 init_process_group 已经被调用用来建立当前的 ddp 通信组;但是这意味着如果 | |||
用户需要使用 2 张以上的显卡,那么其必然需要使用 torch.distributed.launch 来启动,意味着就不是情况 A 了; | |||
这时我们首先会调用 `TorchDDPDriver.open_subprocess` 函数来拉起多个进程,其中进程的数量等于用户传入给 trainer 的使用的 gpu | |||
的数量(例如 `Trainer` 中的参数是 device=[0, 1, 6, 7],那么我们就会使用第 0、1、6、7 张 gpu 来拉起 4 个进程); | |||
接着我们会调用 `dist.init_process_group` 来初始化各个进程之间的通信组; | |||
这里需要注意拉起的新的进程会从前到后完整地运行一遍用户的启动脚本(例如 main.py),因此也都会运行这两个函数,但是需要注意只有进程 0 | |||
才会去真正地运行 `TorchDDPDriver.open_subprocess`;进程 0 运行到 `dist.init_process_group`,pytorch 会阻塞进程 0 继续 | |||
向前运行,直到其它进程也运行到这里; | |||
最后我们会设置这个进程对应的 device,然后将模型迁移到对应的机器上,再使用 `DistributedDataParallel` 将模型包裹; | |||
至此,ddp 的环境配置过程全部完成; | |||
-> 情况 B:注意这种情况我们直接限定了用户是通过 torch.distributed.launch 拉起,并且没有自己建立 ddp 的通信组。这时在 | |||
`TorchDDPDriver` 的初始化和 setup 函数的调用过程中,与情况 A 首要的不同就在于用户在 trainer 中输入的参数 device 不再有效, | |||
这时每个进程所使用的 gpu 是我们直接通过 `torch.device("cuda:{local_rank}")` 来配置的;因此,如果用户想要实现使用特定 gpu | |||
设备的目的,可以通过自己设置环境变量实现(例如 os.environ["CUDA_VISIBLE_DEVICE"] 来实现);剩下的操作和情况 A 类似; | |||
-> 情况 C:注意这种情况我们限定了用户是通过 torch.distributed.launch 拉起,并且 ddp 的通信组也是由自己建立。这时基本上所有的 | |||
与操作相关的操作都应当由用户自己完成,包括迁移模型到对应 gpu 上以及将模型用 `DistributedDataParallel` 包裹等。 | |||
(2)如果 `driver.setup` 函数在脚本中会被调用两次及以上(意味着用户的启动脚本初始化了两个及以上的 trainer/evaluator)时: | |||
注意这种情况下我们是会保证前后两个 trainer/evaluator 使用的 `TorchDDPDriver` 以及其初始化方式的一致性,换句话说,如果 trainer1 | |||
检测到的启动方式是 '情况 A',那么我们会保证 trainer2 检测到的启动方式同样是 '情况A'(即使这需要一些额外的处理);因此这里我们主要讨论 | |||
我们是通过怎样的操作来保证 trainer2/3/... 检测到的启动方式是和 trainer1 一致的;简单来说,我们是通过使用环境变量来标记每一种不同的 | |||
启动方式来实现这一点的: | |||
我们会使用 `FASTNLP_DISTRIBUTED_CHECK` 来标记 '情况 A',使用 `fastnlp_torch_launch_not_ddp` 来标记 '情况 B',意味着我们在 | |||
使用 '情况 A' 来启动 `TorchDDPDriver` 时,我们会将 `FASTNLP_DISTRIBUTED_CHECK` 这一字符串注入到环境变量中,而 '情况 B' 时则 | |||
会将 `fastnlp_torch_launch_not_ddp` 这一字符串注入到环境变量中。因此在 trainer2 的 `TorchDDPDriver` 的初始化和 setup 过程中, | |||
如果检测到这些特殊的环境变量,我们就会将启动方式变更为其对应的启动方式,即使其它的参数特征属于另外的启动方式。 | |||
Part 2:对应的代码细节: | |||
1. 如何判断当前的各进程之间的通信组已经被建立(ddp 已经被初始化); | |||
dist.is_initialized(); | |||
2. 如何判断不同的进程是否是由 `python -m torch.distributed.launch` 拉起还是由我们的 `TorchDDPDriver.open_subprocess` | |||
函数拉起; | |||
我们会在用户脚本 `import fastNLP` 的时候检测当前的环境变量中是否有 'LOCAL_RANK'、'WORLD_SIZE' 以及没有 `FASTNLP_DISTRIBUTED_CHECK`, | |||
如果满足条件,则我们会向环境变量中注入特殊的值 'FASTNLP_BACKEND_LAUNCH' 来标记用户是否使用了 `python -m torch.distributed.launch` | |||
来拉起多个进程; | |||
3. 整体的处理判断流程: | |||
___________________________________ | |||
|进入 TorchDDPDriver 的 __init__ 函数| | |||
——————————————————————————————————— | |||
↓ | |||
___________________________________________________ | |||
| 判断不同的进程是否是由 torch.distributed.launch 拉起 | | |||
|(或者我们自己的 open_subprocess 函数拉起) | --------------> | |||
——————————————————————————————————————————————————— | | |||
↓ 是由 torch.distributed.launch 拉起 | 我们自己的 open_subprocess 函数拉起多个进程 | |||
___________________________ | | |||
←←←←← | 检测用户是否自己初始化了 ddp | | | |||
↓ ——————————————————————————— ↓ | |||
↓ ↓ 是 ________ | |||
↓ ______ | 情况 A | | |||
↓ 否 |情况 C| ————————— | |||
↓ ——————— | |||
↓ | |||
↓ ______ | |||
↓ -----------> |情况 B| | |||
——————— | |||
4. 为了完成全部的建立 ddp 所需要的操作,三种情况都需要做的事情,以及每件事情的职责归属: | |||
情况 A | 情况 B | 情况 C | |||
________________________________________________________________________________________________________ | |||
配置 ddp 所 | TorchDDPDriver.open_subprocess | torch.distributed.launch| torch.distributed.launch | |||
需要的环境变量 | | | | |||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||
开启多个进程 | TorchDDPDriver.open_subprocess | torch.distributed.launch| torch.distributed.launch | |||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||
调用 dist. | | | | |||
init_process\ | TorchDDPDriver.setup | TorchDDPDriver.setup | 用户自己调用 | |||
_group 函数 | | | | |||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||
设置 TorchDDPDriver | | | | |||
的 world_size 和 | TorchDDPDriver.setup | TorchDDPDriver.setup | TorchDDPDriver.setup | |||
global_rank 属性 | | | | |||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||
Part 3:其它的处理细节: | |||
1. 环境变量; | |||
fastNLP 的 `TorchDDPDriver` 运行时所需要的环境变量分为两种,一种是 torch 的 ddp 运行所需要的环境变量;另一种是 fastNLP 自己 | |||
的环境变量。前者的配置情况如上表所示;而后者中的大多数环境变量则是在用户 import fastNLP 时就设置好了; | |||
2. parallel_device, model_device 和 data_device 的关系; | |||
parallel_device 为 `TorchDDPDriver` 的参数,model_device 和 data_device 都为 driver 的属性; | |||
其中 data_device 仅当情况 C 时由用户自己指定;如果其不为 None,那么在模型 forward 的时候,我们就会将数据迁移到 data_device 上; | |||
model_device 永远都为单独的一个 torch.device; | |||
情况 A | 情况 B | 情况 C | |||
________________________________________________________________________________________________________ | |||
parallel_device | 由用户传入trainer的参数 | 为 torch.device( | 为 torch.device( | |||
| device 决定,必须是一个list, | "cuda:{local_rank}") | "cuda:{local_rank}") | |||
| 其中每一个对象都是 torch.device | | | |||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||
model_device | parallel_device[local_rank] | parallel_device | None | |||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||
data_device | model_device | model_device | 由用户传入 trainer 的参数 | |||
| | | data_device 决定 | |||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||
3. _DDPWrappingModel 的作用; | |||
因为我们即需要调用模型的 `train_step`、`evaluate_step`、`test_step` 方法,又需要通过 `DistributedDataParallel` 的 | |||
forward 函数来帮助我们同步各个设备上的梯度,因此我们需要先将模型单独包裹一层,然后在 forward 的时候,其先经过 `DistributedDataParallel` | |||
的 forward 方法,然后再经过 `_DDPWrappingModel` 的 forward 方法,我们会在该 forward 函数中进行判断,确定调用的是模型自己的 | |||
forward 函数,还是 `train_step`、`evaluate_step`、`test_step` 方法。 | |||
4. 当某一个进程出现 exception 后,`TorchDDPDriver` 的处理; | |||
不管是什么情况,`TorchDDPDriver` 在 `setup` 函数的最后,都会将所有进程的 pid 主动记录下来,这样当一个进程出现 exception 后, | |||
driver 的 on_exception 函数就会被 trainer 调用,其会调用 os.kill 指令将其它进程 kill 掉; | |||
""" | |||
# 在加入很多东西后,需要注意这里调用 super 函数的位置; | |||
super(TorchDDPDriver, self).__init__(model, fp16=fp16, **kwargs) | |||
@@ -176,8 +258,9 @@ class TorchDDPDriver(TorchDriver): | |||
self.is_pull_by_torch_run = is_pull_by_torch_run | |||
self.parallel_device = parallel_device | |||
if not is_pull_by_torch_run and parallel_device is None: | |||
raise ValueError("Parameter `parallel_device` can not be None when using `TorchDDPDriver`. This error is caused " | |||
"when your value of parameter `device` is `None` in your `Trainer` instance.") | |||
raise ValueError( | |||
"Parameter `parallel_device` can not be None when using `TorchDDPDriver`. This error is caused " | |||
"when your value of parameter `device` is `None` in your `Trainer` instance.") | |||
# 注意我们在 initialize_torch_driver 中的逻辑就是如果是 is_pull_by_torch_run,那么我们就直接把 parallel_device 置为当前进程的gpu; | |||
if is_pull_by_torch_run: | |||
@@ -233,10 +316,16 @@ class TorchDDPDriver(TorchDriver): | |||
os.makedirs(name=self.output_from_new_proc, exist_ok=True) | |||
self.output_from_new_proc = os.path.abspath(self.output_from_new_proc) | |||
self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; | |||
self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; | |||
self._has_ddpwrapped = False # 判断传入的模型是否经过 _has_ddpwrapped 包裹; | |||
def setup(self): | |||
r""" | |||
准备分布式环境,该函数主要做以下两件事情: | |||
1. 开启多进程,每个 gpu 设备对应单独的一个进程; | |||
2. 每个进程将模型迁移到自己对应的 ``gpu`` 设备上;然后使用 ``DistributedDataParallel`` 包裹模型; | |||
""" | |||
if self._has_setup: | |||
return | |||
self._has_setup = True | |||
@@ -280,9 +369,10 @@ class TorchDDPDriver(TorchDriver): | |||
# 使用的(即之后的)TorchDDPDriver 的设置和第一个 TorchDDPDriver 是完全一样的; | |||
pre_num_processes = int(os.environ[FASTNLP_DISTRIBUTED_CHECK]) | |||
if pre_num_processes != len(self.parallel_device): | |||
raise RuntimeError("Notice you are using `TorchDDPDriver` after one instantiated `TorchDDPDriver`, it is not" | |||
"allowed that your second `TorchDDPDriver` has a new setting of parameters " | |||
"`num_nodes` and `num_processes`.") | |||
raise RuntimeError( | |||
"Notice you are using `TorchDDPDriver` after one instantiated `TorchDDPDriver`, it is not" | |||
"allowed that your second `TorchDDPDriver` has a new setting of parameters " | |||
"`num_nodes` and `num_processes`.") | |||
self.world_size = dist.get_world_size() | |||
self.global_rank = dist.get_rank() | |||
@@ -302,7 +392,7 @@ class TorchDDPDriver(TorchDriver): | |||
local_world_size = local_world_size.tolist() + 1 | |||
node_rank = self.global_rank // local_world_size | |||
self._pids = self._pids[node_rank*local_world_size: (node_rank+1)*local_world_size] | |||
self._pids = self._pids[node_rank * local_world_size: (node_rank + 1) * local_world_size] | |||
self._pids = self.tensor_to_numeric(self._pids) | |||
def configure_ddp(self): | |||
@@ -423,7 +513,8 @@ class TorchDDPDriver(TorchDriver): | |||
return self.model, model.forward | |||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]]=None, | |||
def set_dist_repro_dataloader(self, dataloader, | |||
dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]] = None, | |||
reproducible: bool = False): | |||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load_checkpoint 函数调用; | |||
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; | |||
@@ -505,16 +596,26 @@ class TorchDDPDriver(TorchDriver): | |||
batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False) | |||
return replace_batch_sampler(dataloader, batch_sampler) | |||
else: | |||
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | |||
raise ValueError( | |||
"Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | |||
def is_global_zero(self): | |||
r""" | |||
:return: 返回当前的进程是否在全局上是进程 0 ; | |||
""" | |||
return self.global_rank == 0 | |||
def get_model_no_sync_context(self): | |||
r""" | |||
:return: 返回一个 ``context`` 上下文环境,用于关闭各个进程之间的同步; | |||
""" | |||
# 注意此时的 model 是 "DistributedDataParallel" 对象; | |||
return self.model.no_sync | |||
def unwrap_model(self): | |||
r""" | |||
:return: 返回没有经过 ``DistributedDataParallel`` 包裹的原始模型; | |||
""" | |||
_module = self.model.module | |||
if isinstance(_module, _DDPWrappingModel): | |||
return _module.model | |||
@@ -522,17 +623,26 @@ class TorchDDPDriver(TorchDriver): | |||
return _module | |||
def get_local_rank(self) -> int: | |||
r""" | |||
:return: 返回当前进程局部的进程编号; | |||
""" | |||
return self.local_rank | |||
def barrier(self): | |||
r""" | |||
通过使用该函数来使得各个进程之间同步操作; | |||
""" | |||
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 | |||
torch.distributed.barrier(async_op=False) | |||
def is_distributed(self): | |||
r""" | |||
:return: 返回当前使用的 driver 是否是分布式的 driver,对于 ``TorchDDPDriver`` 来说,该函数一定返回 ``True``; | |||
""" | |||
return True | |||
def broadcast_object(self, obj, src:int=0, group=None, **kwargs): | |||
""" | |||
def broadcast_object(self, obj, src: int = 0, group=None, **kwargs): | |||
r""" | |||
从 src 端将 obj 对象(可能是 tensor ,可能是 object )发送到 dst 处。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行 | |||
传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。 | |||
@@ -540,7 +650,6 @@ class TorchDDPDriver(TorchDriver): | |||
:param int src: source 的 global rank 。 | |||
:param int dst: target 的 global rank,可以是多个目标 rank | |||
:param group: 所属的 group | |||
:param kwargs: | |||
:return: 如果当前不是分布式 driver 直接返回输入的 obj 。如果当前 rank 是接收端(其 global rank 包含在了 dst 中),则返回 | |||
接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 | |||
""" | |||
@@ -549,7 +658,7 @@ class TorchDDPDriver(TorchDriver): | |||
return fastnlp_torch_broadcast_object(obj, src, device=self.data_device, group=group) | |||
def all_gather(self, obj, group) -> List: | |||
""" | |||
r""" | |||
将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过 | |||
pickle 进行序列化,接收到之后再反序列化。 | |||
@@ -578,10 +687,9 @@ class TorchDDPDriver(TorchDriver): | |||
def find_free_network_port() -> str: | |||
"""Finds a free port on localhost. | |||
It is useful in single-node training when we don't want to connect to a real master node but have to set the | |||
`MASTER_PORT` environment variable. | |||
""" | |||
在 localhost 上找到一个空闲端口; | |||
当我们不想连接到真正的主节点但必须设置“MASTER_PORT”环境变量时在单节点训练中很有用; | |||
""" | |||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |||
s.bind(("", 0)) | |||
@@ -145,6 +145,27 @@ def _tensor_to_object(tensor, tensor_size): | |||
def send_recv_object(obj, src, cur_rank, device, group=None, tag=0): | |||
r""" | |||
pytorch 中的单点对多点的分发函数; | |||
例如将进程 0 上的对象 object 分发到其它进程上; | |||
Example:: | |||
cur_rank = int(os.environ.get('LOCAL_RANK', 0)) | |||
# 拿到 local_device | |||
send_recv_object(object, 0, cur_rank, local_device) | |||
:param obj: 一个可以序列化的 python 对象; | |||
:param src: 从哪一个 rank 上发送到其它 rank; | |||
:param cur_rank: 当前的进程的 rank 序号; | |||
:param device: 当前的进程所在的设备; | |||
:param group: 通信组,默认为 None; | |||
:param tag: 将发送与远程接收匹配的标记; | |||
:return: | |||
""" | |||
# src rank send to all other ranks | |||
size = torch.LongTensor([0]).to(device) | |||
@@ -25,7 +25,15 @@ from fastNLP.core.log import logger | |||
class TorchSingleDriver(TorchDriver): | |||
r""" | |||
用于 cpu 和 单卡 gpu 运算; | |||
``TorchSingleDriver`` 是用于 cpu 和 单卡 gpu 运算的 ``driver``; | |||
.. note:: | |||
如果您希望使用 ``DataParallel`` 来训练您的模型,您应当自己在 ``Trainer`` 初始化之前初始化好 ``DataParallel``,然后将其传入 ``Trainer`` 中; | |||
:param model: 传入给 ``Trainer`` 的 ``model`` 参数; | |||
:param device: torch.device,当前进程所使用的设备; | |||
:param fp16: 是否开启 fp16; | |||
""" | |||
def __init__(self, model, device: "torch.device", fp16: bool = False, **kwargs): | |||
@@ -55,6 +63,9 @@ class TorchSingleDriver(TorchDriver): | |||
self.world_size = 1 | |||
def setup(self): | |||
r""" | |||
将模型迁移到相应的设备上; | |||
""" | |||
if self.model_device is not None: | |||
self.model.to(self.model_device) | |||
@@ -135,6 +146,9 @@ class TorchSingleDriver(TorchDriver): | |||
return dataloader | |||
def unwrap_model(self): | |||
r""" | |||
:return: 返回原本的模型,例如没有被 ``DataParallel`` 包裹; | |||
""" | |||
if isinstance(self.model, torch.nn.DataParallel) or \ | |||
isinstance(self.model, torch.nn.parallel.DistributedDataParallel): | |||
return self.model.module | |||
@@ -143,10 +157,13 @@ class TorchSingleDriver(TorchDriver): | |||
@property | |||
def data_device(self): | |||
""" | |||
单卡模式不支持 data_device; | |||
r""" | |||
注意单卡模式下使用 ``driver.data_device`` 等价于使用 ``driver.model_device``; | |||
""" | |||
return self.model_device | |||
def is_distributed(self): | |||
r""" | |||
:return: 返回当前使用的 driver 是否是分布式的 driver,对于 ``TorchSingleDriver`` 来说直接返回 ``False``; | |||
""" | |||
return False |
@@ -36,7 +36,17 @@ from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, | |||
class TorchDriver(Driver): | |||
r""" | |||
专属于 pytorch 的 driver;因为我们会在同一个 Trainer 框架下提供 jittor、paddle 等训练框架的支持; | |||
专属于 ``pytorch`` 的 ``driver``,是 ``TorchSingleDriver`` 和 ``TorchDDPDriver`` 的父类; | |||
.. warning:: | |||
您不应当直接初始化该类,然后传入给 ``Trainer``,换句话说,您应当使用该类的子类 ``TorchSingleDriver`` 和 ``TorchDDPDriver``,而不是 | |||
该类本身; | |||
.. note:: | |||
您可以在使用 ``TorchSingleDriver`` 和 ``TorchDDPDriver`` 时使用 ``TorchDriver`` 提供的接口; | |||
""" | |||
def __init__(self, model, fp16: Optional[bool] = False, **kwargs): | |||
super(TorchDriver, self).__init__(model) | |||
@@ -111,7 +121,15 @@ class TorchDriver(Driver): | |||
f"not {type(each_optimizer)}.") | |||
@staticmethod | |||
def tensor_to_numeric(tensor, reduce=None): | |||
def tensor_to_numeric(tensor, reduce: str = None): | |||
r""" | |||
将 ``torch.Tensor`` 转换成 python 中的数值类型; | |||
:param tensor: ``torch.Tensor``; | |||
:param reduce: 当 tensor 是一个多数值的张量时,应当使用何种归一化操作来转换成单一数值,应当为以下类型之一:``['max', 'min', 'sum', 'mean']``; | |||
:return: 返回一个单一数值,其数值类型是 python 中的基本的数值类型,例如 ``int,float`` 等; | |||
""" | |||
if tensor is None: | |||
return None | |||
@@ -129,6 +147,10 @@ class TorchDriver(Driver): | |||
) | |||
def set_model_mode(self, mode: str): | |||
r""" | |||
设置模型的状态是 ``train`` 还是 ``eval``; | |||
:param mode: ``train`` 或者 ``eval``; | |||
""" | |||
assert mode in {"train", "eval"} | |||
getattr(self.model, mode)() | |||
@@ -326,14 +348,26 @@ class TorchDriver(Driver): | |||
return states | |||
def get_evaluate_context(self): | |||
r""" | |||
:return: 返回 ``torch.no_grad`` 这个 context; | |||
""" | |||
return torch.no_grad | |||
@staticmethod | |||
def move_model_to_device(model: "torch.nn.Module", device: "torch.device"): | |||
r""" | |||
将模型迁移到对应的设备上; | |||
""" | |||
if device is not None: | |||
model.to(device) | |||
def move_data_to_device(self, batch: "torch.Tensor"): | |||
def move_data_to_device(self, batch): | |||
""" | |||
将一个 batch 的数据迁移到对应的设备上; | |||
:param batch: 一个 batch 的数据,可以是 ``list、dict`` 等; | |||
:return: | |||
""" | |||
return torch_move_data_to_device(batch, self.data_device, self.non_blocking) | |||
@staticmethod | |||
@@ -174,7 +174,7 @@ def _build_fp16_env(dummy=False): | |||
def replace_sampler(dataloader: "DataLoader", sampler): | |||
""" | |||
r""" | |||
替换 sampler (初始化一个新的 dataloader 的逻辑在于): | |||
用户可能继承了 dataloader,定制了自己的 dataloader 类,这也是我们为什么先 `inspect.signature(dataloader)` 而不是直接 | |||
@@ -259,7 +259,7 @@ def replace_sampler(dataloader: "DataLoader", sampler): | |||
def _dataloader_init_kwargs_resolve_sampler( | |||
dataloader: "DataLoader", sampler: Optional["Sampler"] | |||
) -> Dict[str, Any]: | |||
""" | |||
r""" | |||
此函数用于处理与 DataLoader 关联的采样器、batch_sampler 参数重新实例化; | |||
""" | |||
batch_sampler = getattr(dataloader, "batch_sampler") | |||
@@ -279,15 +279,8 @@ def _dataloader_init_kwargs_resolve_sampler( | |||
def replace_batch_sampler(dataloader, new_batch_sampler): | |||
"""Helper function to replace current batch sampler of the dataloader by a new batch sampler. Function returns new | |||
dataloader with new batch sampler. | |||
Args: | |||
dataloader: input dataloader | |||
new_batch_sampler: new batch sampler to use | |||
Returns: | |||
DataLoader | |||
r""" | |||
替换一个 dataloader 的 batch_sampler; | |||
""" | |||
params_keys = [k for k in dataloader.__dict__.keys() if not k.startswith("_")] | |||
for k in ["batch_size", "sampler", "drop_last", "batch_sampler", "dataset_kind"]: | |||
@@ -296,12 +289,16 @@ def replace_batch_sampler(dataloader, new_batch_sampler): | |||
params = {k: getattr(dataloader, k) for k in params_keys} | |||
params["batch_sampler"] = new_batch_sampler | |||
return type(dataloader)(**params) | |||
# TODO 这里是否可以auto_param_call一下 | |||
# return auto_param_call(type(dataloader), params, {'self': type(dataloader).__new__()}, | |||
# signature_fn=type(dataloader).__init__) | |||
def optimizer_state_to_device(state, device): | |||
r""" | |||
将一个 ``optimizer`` 的 ``state_dict`` 迁移到对应的设备; | |||
:param state: ``optimzier.state_dict()``; | |||
:param device: 要迁移到的目的设备; | |||
:return: 返回迁移后的新的 state_dict; | |||
""" | |||
new_state = {} | |||
for name, param in state.items(): | |||
if isinstance(param, dict): | |||
@@ -3,7 +3,7 @@ import subprocess | |||
def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy:dict, rank:int=None): | |||
""" | |||
r""" | |||
使用 command 通过 subprocess.Popen 开启新的进程。 | |||
:param output_from_new_proc: 可选 ["ignore", "all", "only_error"],以上三个为特殊关键字,分别表示完全忽略拉起进程的打印输出, | |||
@@ -11,8 +11,8 @@ def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy: | |||
两个文件,名称分别为 {rank}_std.log, {rank}_err.log 。原有的文件会被直接覆盖。 | |||
:param command: List[str] 启动的命令 | |||
:param env_copy: 需要注入的环境变量。 | |||
:param rank: | |||
:return: | |||
:param rank: global_rank; | |||
:return: 返回使用 ``subprocess.Popen`` 打开的进程; | |||
""" | |||
if output_from_new_proc == "all": | |||
proc = subprocess.Popen(command, env=env_copy) | |||
@@ -86,9 +86,11 @@ | |||
"\n", | |||
"  具体`driver`与`Trainer`以及`Evaluator`之间的关系请参考`fastNLP 0.8`的框架设计\n", | |||
"\n", | |||
"注:在同一脚本中,`Trainer`和`Evaluator`使用的`driver`应当保持一致\n", | |||
"注:这里给出一条建议:**在同一脚本中**,**所有的`Trainer`和`Evaluator`使用的`driver`应当保持一致**\n", | |||
"\n", | |||
"  一个不能违背的原则在于:**不要将多卡的`driver`前使用单卡的`driver`**(???),这样使用可能会带来很多意想不到的错误" | |||
"  尽量不出现,之前使用单卡的`driver`,后面又使用多卡的`driver`,这是因为,当脚本执行至\n", | |||
"\n", | |||
"  多卡`driver`处时,会重启一个进程执行之前所有内容,如此一来可能会造成一些意想不到的麻烦" | |||
] | |||
}, | |||
{ | |||
@@ -167,7 +169,7 @@ | |||
"\n", | |||
"注:在`fastNLP 0.8`中,**`Trainer`要求模型通过`train_step`来返回一个字典**,**满足如`{\"loss\": loss}`的形式**\n", | |||
"\n", | |||
"  此外,这里也可以通过传入`Trainer`的参数`output_mapping`来实现高度化的定制,具体请见这一note(???)\n", | |||
"  此外,这里也可以通过传入`Trainer`的参数`output_mapping`来实现输出的转换,详见(trainer的详细讲解,待补充)\n", | |||
"\n", | |||
"同样,在`fastNLP 0.8`中,**函数`evaluate_step`是`Evaluator`中参数`evaluate_fn`的默认值**\n", | |||
"\n", | |||
@@ -177,7 +179,7 @@ | |||
"\n", | |||
"  从模块角度,该字典的键值和`metric`中的`update`函数的签名一致,这样的机制在传参时被称为“**参数匹配**”\n", | |||
"\n", | |||
"<img src=\"./figures/T0-fig-trainer-and-evaluator.png\" width=\"80%\" height=\"80%\" align=\"center\"></img>" | |||
"<img src=\"./figures/T0-fig-training-structure.png\" width=\"68%\" height=\"68%\" align=\"center\"></img>" | |||
] | |||
}, | |||
{ | |||
@@ -216,8 +218,14 @@ | |||
"\n", | |||
" def __getitem__(self, item):\n", | |||
" return {\"x\": self.x[item], \"y\": self.y[item]}\n", | |||
"```\n", | |||
"***\n", | |||
"```" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"id": "f5f1a6aa", | |||
"metadata": {}, | |||
"source": [ | |||
"对于后者,首先要明确,在`Trainer`和`Evaluator`中,`metrics`的计算分为`update`和`get_metric`两步\n", | |||
"\n", | |||
"    **`update`函数**,**针对一个`batch`的预测结果**,计算其累计的评价指标\n", | |||
@@ -230,7 +238,9 @@ | |||
"\n", | |||
"  在此基础上,**`fastNLP 0.8`要求`evaluate_dataloader`生成的每个`batch`传递给对应的`metric`**\n", | |||
"\n", | |||
"    **以`{\"pred\": y_pred, \"target\": y_true}`的形式**,对应其`update`函数的函数签名" | |||
"    **以`{\"pred\": y_pred, \"target\": y_true}`的形式**,对应其`update`函数的函数签名\n", | |||
"\n", | |||
"<img src=\"./figures/T0-fig-parameter-matching.png\" width=\"75%\" height=\"75%\" align=\"center\"></img>" | |||
] | |||
}, | |||
{ | |||
@@ -639,11 +649,11 @@ | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'acc#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.29</span><span style=\"font-weight: bold\">}</span>\n", | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'acc#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.39</span><span style=\"font-weight: bold\">}</span>\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.29\u001b[0m\u001b[1m}\u001b[0m\n" | |||
"\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.39\u001b[0m\u001b[1m}\u001b[0m\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
@@ -652,7 +662,7 @@ | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"{'acc#acc': 0.29}" | |||
"{'acc#acc': 0.39}" | |||
] | |||
}, | |||
"execution_count": 9, | |||
@@ -710,7 +720,9 @@ | |||
"source": [ | |||
"通过使用`Trainer`类的`run`函数,进行训练\n", | |||
"\n", | |||
"  还可以通过参数`num_eval_sanity_batch`决定每次训练前运行多少个`evaluate_batch`进行评测,默认为2" | |||
"  还可以通过参数`num_eval_sanity_batch`决定每次训练前运行多少个`evaluate_batch`进行评测,默认为2\n", | |||
"\n", | |||
"  之所以“先评测后训练”,是为了保证训练很长时间的数据,不会在评测阶段出问题,故作此试探性评测" | |||
] | |||
}, | |||
{ | |||
@@ -773,6 +785,14 @@ | |||
"source": [ | |||
"trainer.run()" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"id": "c4e9c619", | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [] | |||
} | |||
], | |||
"metadata": { | |||
@@ -153,7 +153,7 @@ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"2438703969992 2438374526920\n", | |||
"1608199516936 1607874531400\n", | |||
"+-----+------------------------+------------------------+-----+\n", | |||
"| idx | sentence | words | num |\n", | |||
"+-----+------------------------+------------------------+-----+\n", | |||
@@ -183,7 +183,7 @@ | |||
"id": "aa277674", | |||
"metadata": {}, | |||
"source": [ | |||
"  注二:在`fastNLP 0.8`中,**对`dataset`使用等号**,**其效果是传引用**,**而不是赋值**(???)\n", | |||
"  注二:**对对象使用等号一般表示传引用**,所以对`dataset`使用等号,是传引用而不是赋值\n", | |||
"\n", | |||
"    如下所示,**`dropped`和`dataset`具有相同`id`**,**对`dropped`执行删除操作`dataset`同时会被修改**" | |||
] | |||
@@ -198,7 +198,7 @@ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"2438374526920 2438374526920\n", | |||
"1607874531400 1607874531400\n", | |||
"+-----+------------------------+------------------------+-----+\n", | |||
"| idx | sentence | words | num |\n", | |||
"+-----+------------------------+------------------------+-----+\n", | |||
@@ -296,9 +296,9 @@ | |||
"\n", | |||
"在`dataset`模块中,`apply`、`apply_field`、`apply_more`和`apply_field_more`函数可以进行简单的数据预处理\n", | |||
"\n", | |||
"  **`apply`和`apply_more`针对整条实例**,**`apply_field`和`apply_field_more`仅针对实例的部分字段**\n", | |||
"  **`apply`和`apply_more`输入整条实例**,**`apply_field`和`apply_field_more`仅输入实例的部分字段**\n", | |||
"\n", | |||
"  **`apply`和`apply_field`仅针对单个字段**,**`apply_more`和`apply_field_more`则可以针对多个字段**\n", | |||
"  **`apply`和`apply_field`仅输出单个字段**,**`apply_more`和`apply_field_more`则是输出多个字段**\n", | |||
"\n", | |||
"  **`apply`和`apply_field`返回的是个列表**,**`apply_more`和`apply_field_more`返回的是个字典**\n", | |||
"\n", | |||
@@ -311,14 +311,14 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 7, | |||
"execution_count": null, | |||
"id": "72a0b5f9", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"application/vnd.jupyter.widget-view+json": { | |||
"model_id": "", | |||
"model_id": "8532c5609a394c19b60315663a6f0f4a", | |||
"version_major": 2, | |||
"version_minor": 0 | |||
}, | |||
@@ -328,42 +328,6 @@ | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |||
], | |||
"text/plain": [] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"+-----+------------------------------+------------------------------+\n", | |||
"| idx | sentence | words |\n", | |||
"+-----+------------------------------+------------------------------+\n", | |||
"| 0 | This is an apple . | ['This', 'is', 'an', 'app... |\n", | |||
"| 1 | I like apples . | ['I', 'like', 'apples', '... |\n", | |||
"| 2 | Apples are good for our h... | ['Apples', 'are', 'good',... |\n", | |||
"+-----+------------------------------+------------------------------+\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
@@ -384,57 +348,10 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 8, | |||
"execution_count": null, | |||
"id": "b1a8631f", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |||
], | |||
"text/plain": [] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |||
], | |||
"text/plain": [] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"+-----+------------------------------+------------------------------+\n", | |||
"| idx | sentence | words |\n", | |||
"+-----+------------------------------+------------------------------+\n", | |||
"| 0 | This is an apple . | ['This', 'is', 'an', 'app... |\n", | |||
"| 1 | I like apples . | ['I', 'like', 'apples', '... |\n", | |||
"| 2 | Apples are good for our h... | ['Apples', 'are', 'good',... |\n", | |||
"+-----+------------------------------+------------------------------+\n" | |||
] | |||
} | |||
], | |||
"outputs": [], | |||
"source": [ | |||
"dataset = DataSet(data)\n", | |||
"\n", | |||
@@ -459,57 +376,10 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 9, | |||
"execution_count": null, | |||
"id": "057c1d2c", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |||
], | |||
"text/plain": [] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |||
], | |||
"text/plain": [] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"+-----+------------------------------+------------------------------+\n", | |||
"| idx | sentence | words |\n", | |||
"+-----+------------------------------+------------------------------+\n", | |||
"| 0 | This is an apple . | ['This', 'is', 'an', 'app... |\n", | |||
"| 1 | I like apples . | ['I', 'like', 'apples', '... |\n", | |||
"| 2 | Apples are good for our h... | ['Apples', 'are', 'good',... |\n", | |||
"+-----+------------------------------+------------------------------+\n" | |||
] | |||
} | |||
], | |||
"outputs": [], | |||
"source": [ | |||
"dataset = DataSet(data)\n", | |||
"dataset.apply_field(lambda sent:sent.split(), field_name='sentence', new_field_name='words')\n", | |||
@@ -528,57 +398,10 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 10, | |||
"execution_count": null, | |||
"id": "51e2f02c", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |||
], | |||
"text/plain": [] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |||
], | |||
"text/plain": [] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"+-----+------------------------+------------------------+-----+\n", | |||
"| idx | sentence | words | num |\n", | |||
"+-----+------------------------+------------------------+-----+\n", | |||
"| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", | |||
"| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", | |||
"| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", | |||
"+-----+------------------------+------------------------+-----+\n" | |||
] | |||
} | |||
], | |||
"outputs": [], | |||
"source": [ | |||
"dataset = DataSet(data)\n", | |||
"dataset.apply_more(lambda ins:{'words': ins['sentence'].split(), 'num': len(ins['sentence'].split())})\n", | |||
@@ -597,57 +420,10 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 11, | |||
"execution_count": null, | |||
"id": "db4295d5", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |||
], | |||
"text/plain": [] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |||
], | |||
"text/plain": [] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"+-----+------------------------+------------------------+-----+\n", | |||
"| idx | sentence | words | num |\n", | |||
"+-----+------------------------+------------------------+-----+\n", | |||
"| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", | |||
"| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", | |||
"| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", | |||
"+-----+------------------------+------------------------+-----+\n" | |||
] | |||
} | |||
], | |||
"outputs": [], | |||
"source": [ | |||
"dataset = DataSet(data)\n", | |||
"dataset.apply_field_more(lambda sent:{'words': sent.split(), 'num': len(sent.split())}, \n", | |||
@@ -669,7 +445,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 12, | |||
"execution_count": null, | |||
"id": "012f537c", | |||
"metadata": {}, | |||
"outputs": [], | |||
@@ -700,20 +476,10 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 13, | |||
"execution_count": null, | |||
"id": "a4c1c10d", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"dict_items([('sentence', 'This is an apple .'), ('words', ['This', 'is', 'an', 'apple', '.']), ('num', 5)])\n", | |||
"dict_keys(['sentence', 'words', 'num'])\n", | |||
"dict_values(['This is an apple .', ['This', 'is', 'an', 'apple', '.'], 5])\n" | |||
] | |||
} | |||
], | |||
"outputs": [], | |||
"source": [ | |||
"ins = Instance(sentence=\"This is an apple .\", words=['This', 'is', 'an', 'apple', '.'], num=5)\n", | |||
"\n", | |||
@@ -732,22 +498,10 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 14, | |||
"execution_count": null, | |||
"id": "55376402", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"+--------------------+------------------------+-----+-----+\n", | |||
"| sentence | words | num | idx |\n", | |||
"+--------------------+------------------------+-----+-----+\n", | |||
"| This is an apple . | ['This', 'is', 'an'... | 5 | 0 |\n", | |||
"+--------------------+------------------------+-----+-----+\n" | |||
] | |||
} | |||
], | |||
"outputs": [], | |||
"source": [ | |||
"ins.add_field(field_name='idx', field=0)\n", | |||
"print(ins)" | |||
@@ -767,44 +521,20 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 15, | |||
"execution_count": null, | |||
"id": "fe15f4c1", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"{'sentence': <fastNLP.core.dataset.field.FieldArray at 0x237ce26d388>,\n", | |||
" 'words': <fastNLP.core.dataset.field.FieldArray at 0x237ce26d408>,\n", | |||
" 'num': <fastNLP.core.dataset.field.FieldArray at 0x237ce26d488>}" | |||
] | |||
}, | |||
"execution_count": 15, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
} | |||
], | |||
"outputs": [], | |||
"source": [ | |||
"dataset.get_all_fields()" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 16, | |||
"execution_count": null, | |||
"id": "5433815c", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"['num', 'sentence', 'words']" | |||
] | |||
}, | |||
"execution_count": 16, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
} | |||
], | |||
"outputs": [], | |||
"source": [ | |||
"dataset.get_field_names()" | |||
] | |||
@@ -823,29 +553,10 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 17, | |||
"execution_count": null, | |||
"id": "25ce5488", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"3 False\n", | |||
"6 True\n", | |||
"+------------------------------+------------------------------+--------+\n", | |||
"| sentence | words | length |\n", | |||
"+------------------------------+------------------------------+--------+\n", | |||
"| This is an apple . | ['This', 'is', 'an', 'app... | 5 |\n", | |||
"| I like apples . | ['I', 'like', 'apples', '... | 4 |\n", | |||
"| Apples are good for our h... | ['Apples', 'are', 'good',... | 7 |\n", | |||
"| This is an apple . | ['This', 'is', 'an', 'app... | 5 |\n", | |||
"| I like apples . | ['I', 'like', 'apples', '... | 4 |\n", | |||
"| Apples are good for our h... | ['Apples', 'are', 'good',... | 7 |\n", | |||
"+------------------------------+------------------------------+--------+\n" | |||
] | |||
} | |||
], | |||
"outputs": [], | |||
"source": [ | |||
"print(len(dataset), dataset.has_field('length')) \n", | |||
"if 'num' in dataset:\n", | |||
@@ -877,21 +588,10 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 18, | |||
"execution_count": null, | |||
"id": "3515e096", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Vocabulary([]...)\n", | |||
"{'<pad>': 0, '<unk>': 1}\n", | |||
"<pad> 0\n", | |||
"<unk> 1\n" | |||
] | |||
} | |||
], | |||
"outputs": [], | |||
"source": [ | |||
"from fastNLP.core.vocabulary import Vocabulary\n", | |||
"\n", | |||
@@ -914,20 +614,10 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 19, | |||
"execution_count": null, | |||
"id": "88c7472a", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"5 Counter({'生活': 1, '就像': 1, '海洋': 1})\n", | |||
"6 Counter({'生活': 1, '就像': 1, '海洋': 1, '只有': 1})\n", | |||
"6 {'<pad>': 0, '<unk>': 1, '生活': 2, '就像': 3, '海洋': 4, '只有': 5}\n" | |||
] | |||
} | |||
], | |||
"outputs": [], | |||
"source": [ | |||
"vocab.add_word_lst(['生活', '就像', '海洋'])\n", | |||
"print(len(vocab), vocab.word_count)\n", | |||
@@ -950,21 +640,10 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 20, | |||
"execution_count": null, | |||
"id": "3447acde", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"<pad> 0\n", | |||
"<unk> 1\n", | |||
"生活 2\n", | |||
"彼岸 1 False\n" | |||
] | |||
} | |||
], | |||
"outputs": [], | |||
"source": [ | |||
"print(vocab.to_word(0), vocab.to_index('<pad>'))\n", | |||
"print(vocab.to_word(1), vocab.to_index('<unk>'))\n", | |||
@@ -986,21 +665,10 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 21, | |||
"execution_count": null, | |||
"id": "490b101c", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"生活 2\n", | |||
"彼岸 12 True\n", | |||
"13 Counter({'人': 4, '生活': 2, '就像': 2, '海洋': 2, '只有': 2, '意志': 1, '坚强的': 1, '才': 1, '能': 1, '到达': 1, '彼岸': 1})\n", | |||
"13 {'<pad>': 0, '<unk>': 1, '生活': 2, '就像': 3, '海洋': 4, '只有': 5, '人': 6, '意志': 7, '坚强的': 8, '才': 9, '能': 10, '到达': 11, '彼岸': 12}\n" | |||
] | |||
} | |||
], | |||
"outputs": [], | |||
"source": [ | |||
"vocab.add_word_lst(['生活', '就像', '海洋', '只有', '意志', '坚强的', '人', '人', '人', '人', '才', '能', '到达', '彼岸'])\n", | |||
"print(vocab.to_word(2), vocab.to_index('生活'))\n", | |||
@@ -1023,19 +691,10 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 22, | |||
"execution_count": null, | |||
"id": "a99ff909", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"{'positive': 0, 'negative': 1}\n", | |||
"ValueError: word `neutral` not in vocabulary\n" | |||
] | |||
} | |||
], | |||
"outputs": [], | |||
"source": [ | |||
"vocab = Vocabulary(unknown=None, padding=None)\n", | |||
"\n", | |||
@@ -1058,19 +717,10 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 23, | |||
"execution_count": null, | |||
"id": "432f74c1", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"{'<unk>': 0, 'positive': 1, 'negative': 2}\n", | |||
"0 <unk>\n" | |||
] | |||
} | |||
], | |||
"outputs": [], | |||
"source": [ | |||
"vocab = Vocabulary(unknown='<unk>', padding=None)\n", | |||
"\n", | |||
@@ -1096,92 +746,10 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 24, | |||
"execution_count": null, | |||
"id": "3dbd985d", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<div>\n", | |||
"<style scoped>\n", | |||
" .dataframe tbody tr th:only-of-type {\n", | |||
" vertical-align: middle;\n", | |||
" }\n", | |||
"\n", | |||
" .dataframe tbody tr th {\n", | |||
" vertical-align: top;\n", | |||
" }\n", | |||
"\n", | |||
" .dataframe thead th {\n", | |||
" text-align: right;\n", | |||
" }\n", | |||
"</style>\n", | |||
"<table border=\"1\" class=\"dataframe\">\n", | |||
" <thead>\n", | |||
" <tr style=\"text-align: right;\">\n", | |||
" <th></th>\n", | |||
" <th>SentenceId</th>\n", | |||
" <th>Sentence</th>\n", | |||
" <th>Sentiment</th>\n", | |||
" </tr>\n", | |||
" </thead>\n", | |||
" <tbody>\n", | |||
" <tr>\n", | |||
" <th>0</th>\n", | |||
" <td>1</td>\n", | |||
" <td>A series of escapades demonstrating the adage ...</td>\n", | |||
" <td>negative</td>\n", | |||
" </tr>\n", | |||
" <tr>\n", | |||
" <th>1</th>\n", | |||
" <td>2</td>\n", | |||
" <td>This quiet , introspective and entertaining in...</td>\n", | |||
" <td>positive</td>\n", | |||
" </tr>\n", | |||
" <tr>\n", | |||
" <th>2</th>\n", | |||
" <td>3</td>\n", | |||
" <td>Even fans of Ismail Merchant 's work , I suspe...</td>\n", | |||
" <td>negative</td>\n", | |||
" </tr>\n", | |||
" <tr>\n", | |||
" <th>3</th>\n", | |||
" <td>4</td>\n", | |||
" <td>A positively thrilling combination of ethnogra...</td>\n", | |||
" <td>neutral</td>\n", | |||
" </tr>\n", | |||
" <tr>\n", | |||
" <th>4</th>\n", | |||
" <td>5</td>\n", | |||
" <td>A comedy-drama of nearly epic proportions root...</td>\n", | |||
" <td>positive</td>\n", | |||
" </tr>\n", | |||
" <tr>\n", | |||
" <th>5</th>\n", | |||
" <td>6</td>\n", | |||
" <td>The Importance of Being Earnest , so thick wit...</td>\n", | |||
" <td>neutral</td>\n", | |||
" </tr>\n", | |||
" </tbody>\n", | |||
"</table>\n", | |||
"</div>" | |||
], | |||
"text/plain": [ | |||
" SentenceId Sentence Sentiment\n", | |||
"0 1 A series of escapades demonstrating the adage ... negative\n", | |||
"1 2 This quiet , introspective and entertaining in... positive\n", | |||
"2 3 Even fans of Ismail Merchant 's work , I suspe... negative\n", | |||
"3 4 A positively thrilling combination of ethnogra... neutral\n", | |||
"4 5 A comedy-drama of nearly epic proportions root... positive\n", | |||
"5 6 The Importance of Being Earnest , so thick wit... neutral" | |||
] | |||
}, | |||
"execution_count": 24, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
} | |||
], | |||
"outputs": [], | |||
"source": [ | |||
"import pandas as pd\n", | |||
"\n", | |||
@@ -1199,60 +767,10 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 25, | |||
"execution_count": null, | |||
"id": "4f634586", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |||
], | |||
"text/plain": [] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |||
], | |||
"text/plain": [] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"+------------+------------------------------+-----------+\n", | |||
"| SentenceId | Sentence | Sentiment |\n", | |||
"+------------+------------------------------+-----------+\n", | |||
"| 1 | ['a', 'series', 'of', 'es... | negative |\n", | |||
"| 2 | ['this', 'quiet', ',', 'i... | positive |\n", | |||
"| 3 | ['even', 'fans', 'of', 'i... | negative |\n", | |||
"| 4 | ['a', 'positively', 'thri... | neutral |\n", | |||
"| 5 | ['a', 'comedy-drama', 'of... | positive |\n", | |||
"| 6 | ['the', 'importance', 'of... | neutral |\n", | |||
"+------------+------------------------------+-----------+\n" | |||
] | |||
} | |||
], | |||
"outputs": [], | |||
"source": [ | |||
"from fastNLP.core.dataset import DataSet\n", | |||
"\n", | |||
@@ -1273,7 +791,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 26, | |||
"execution_count": null, | |||
"id": "46722efc", | |||
"metadata": {}, | |||
"outputs": [], | |||
@@ -1297,55 +815,10 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 27, | |||
"execution_count": null, | |||
"id": "a2de615b", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |||
], | |||
"text/plain": [] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |||
], | |||
"text/plain": [] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Counter({'a': 9, 'of': 9, ',': 7, 'the': 6, '.': 5, 'is': 3, 'and': 3, 'good': 2, 'for': 2, 'which': 2, 'this': 2, \"'s\": 2, 'series': 1, 'escapades': 1, 'demonstrating': 1, 'adage': 1, 'that': 1, 'what': 1, 'goose': 1, 'also': 1, 'gander': 1, 'some': 1, 'occasionally': 1, 'amuses': 1, 'but': 1, 'none': 1, 'amounts': 1, 'to': 1, 'much': 1, 'story': 1, 'quiet': 1, 'introspective': 1, 'entertaining': 1, 'independent': 1, 'worth': 1, 'seeking': 1, 'even': 1, 'fans': 1, 'ismail': 1, 'merchant': 1, 'work': 1, 'i': 1, 'suspect': 1, 'would': 1, 'have': 1, 'hard': 1, 'time': 1, 'sitting': 1, 'through': 1, 'one': 1, 'positively': 1, 'thrilling': 1, 'combination': 1, 'ethnography': 1, 'all': 1, 'intrigue': 1, 'betrayal': 1, 'deceit': 1, 'murder': 1, 'shakespearean': 1, 'tragedy': 1, 'or': 1, 'juicy': 1, 'soap': 1, 'opera': 1, 'comedy-drama': 1, 'nearly': 1, 'epic': 1, 'proportions': 1, 'rooted': 1, 'in': 1, 'sincere': 1, 'performance': 1, 'by': 1, 'title': 1, 'character': 1, 'undergoing': 1, 'midlife': 1, 'crisis': 1, 'importance': 1, 'being': 1, 'earnest': 1, 'so': 1, 'thick': 1, 'with': 1, 'wit': 1, 'it': 1, 'plays': 1, 'like': 1, 'reading': 1, 'from': 1, 'bartlett': 1, 'familiar': 1, 'quotations': 1}) \n", | |||
"\n", | |||
"{'<pad>': 0, '<unk>': 1, 'a': 2, 'of': 3, ',': 4, 'the': 5, '.': 6, 'is': 7, 'and': 8, 'good': 9, 'for': 10, 'which': 11, 'this': 12, \"'s\": 13, 'series': 14, 'escapades': 15, 'demonstrating': 16, 'adage': 17, 'that': 18, 'what': 19, 'goose': 20, 'also': 21, 'gander': 22, 'some': 23, 'occasionally': 24, 'amuses': 25, 'but': 26, 'none': 27, 'amounts': 28, 'to': 29, 'much': 30, 'story': 31, 'quiet': 32, 'introspective': 33, 'entertaining': 34, 'independent': 35, 'worth': 36, 'seeking': 37, 'even': 38, 'fans': 39, 'ismail': 40, 'merchant': 41, 'work': 42, 'i': 43, 'suspect': 44, 'would': 45, 'have': 46, 'hard': 47, 'time': 48, 'sitting': 49, 'through': 50, 'one': 51, 'positively': 52, 'thrilling': 53, 'combination': 54, 'ethnography': 55, 'all': 56, 'intrigue': 57, 'betrayal': 58, 'deceit': 59, 'murder': 60, 'shakespearean': 61, 'tragedy': 62, 'or': 63, 'juicy': 64, 'soap': 65, 'opera': 66, 'comedy-drama': 67, 'nearly': 68, 'epic': 69, 'proportions': 70, 'rooted': 71, 'in': 72, 'sincere': 73, 'performance': 74, 'by': 75, 'title': 76, 'character': 77, 'undergoing': 78, 'midlife': 79, 'crisis': 80, 'importance': 81, 'being': 82, 'earnest': 83, 'so': 84, 'thick': 85, 'with': 86, 'wit': 87, 'it': 88, 'plays': 89, 'like': 90, 'reading': 91, 'from': 92, 'bartlett': 93, 'familiar': 94, 'quotations': 95} \n", | |||
"\n", | |||
"Vocabulary(['a', 'series', 'of', 'escapades', 'demonstrating']...)\n" | |||
] | |||
} | |||
], | |||
"outputs": [], | |||
"source": [ | |||
"from fastNLP.core.vocabulary import Vocabulary\n", | |||
"\n", | |||
@@ -1368,60 +841,10 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 28, | |||
"execution_count": null, | |||
"id": "2f9a04b2", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |||
], | |||
"text/plain": [] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |||
], | |||
"text/plain": [] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"+------------+------------------------------+-----------+\n", | |||
"| SentenceId | Sentence | Sentiment |\n", | |||
"+------------+------------------------------+-----------+\n", | |||
"| 1 | [2, 14, 3, 15, 16, 5, 17,... | negative |\n", | |||
"| 2 | [12, 32, 4, 33, 8, 34, 35... | positive |\n", | |||
"| 3 | [38, 39, 3, 40, 41, 13, 4... | negative |\n", | |||
"| 4 | [2, 52, 53, 54, 3, 55, 8,... | neutral |\n", | |||
"| 5 | [2, 67, 3, 68, 69, 70, 71... | positive |\n", | |||
"| 6 | [5, 81, 3, 82, 83, 4, 84,... | neutral |\n", | |||
"+------------+------------------------------+-----------+\n" | |||
] | |||
} | |||
], | |||
"outputs": [], | |||
"source": [ | |||
"vocab.index_dataset(dataset, field_name='Sentence')\n", | |||
"print(dataset)" | |||
@@ -1437,67 +860,10 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 29, | |||
"execution_count": null, | |||
"id": "5f5eed18", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |||
], | |||
"text/plain": [] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"{'negative': 0, 'positive': 1, 'neutral': 2}\n" | |||
] | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |||
], | |||
"text/plain": [] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"+------------+------------------------------+-----------+\n", | |||
"| SentenceId | Sentence | Sentiment |\n", | |||
"+------------+------------------------------+-----------+\n", | |||
"| 1 | [2, 14, 3, 15, 16, 5, 17,... | 0 |\n", | |||
"| 2 | [12, 32, 4, 33, 8, 34, 35... | 1 |\n", | |||
"| 3 | [38, 39, 3, 40, 41, 13, 4... | 0 |\n", | |||
"| 4 | [2, 52, 53, 54, 3, 55, 8,... | 2 |\n", | |||
"| 5 | [2, 67, 3, 68, 69, 70, 71... | 1 |\n", | |||
"| 6 | [5, 81, 3, 82, 83, 4, 84,... | 2 |\n", | |||
"+------------+------------------------------+-----------+\n" | |||
] | |||
} | |||
], | |||
"outputs": [], | |||
"source": [ | |||
"target_vocab = Vocabulary(padding=None, unknown=None)\n", | |||
"\n", | |||