@@ -10,6 +10,7 @@ from .utils import (
     _MODE_PARAMETER,
     get_device_from_visible,
     reset_seed,
+    replace_sampler
 )

 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
@@ -19,8 +20,13 @@ from fastNLP.core.utils import (
     paddle_move_data_to_device,
     is_in_paddle_dist,
 )
-from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler
-from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES
+from fastNLP.core.samplers import (
+    ReproducibleIterator,
+    RandomSampler,
+    UnrepeatedDistributedSampler,
+    re_instantiate_sampler,
+)
+from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED
 from fastNLP.core.log import logger

 if _NEED_IMPORT_PADDLE:
@@ -314,23 +320,15 @@ class PaddleFleetDriver(PaddleDriver):
     def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
                                   reproducible: bool = False, sampler_or_batch_sampler=None):
         # IterableDataset is not supported for now
         assert dataloader.dataset_kind != _DatasetKind.ITER, \
             "FastNLP does not support `IteratorDataset` now."

         if isinstance(dist, ReproducibleIterator):
-            dataloader.batch_sampler.sampler = dist
-            return dataloader
-
-        # paddle's BatchSampler and DataLoader have no `shuffle` member, so shuffle can only be inferred from the
-        # sampler; its subclass DistributedBatchSampler, however, does have a `shuffle` member,
-        # hence the strict type() check here
-        if type(dataloader.batch_sampler) == BatchSampler:
-            shuffle = isinstance(dataloader.batch_sampler.sampler, RandomSampler)
-        else:
-            shuffle = dataloader.batch_sampler.shuffle
+            dist = re_instantiate_sampler(dist)
+            return replace_sampler(dataloader, dist)

         # trainer, evaluator
-        # the user set up the distributed environment themselves; do nothing
         if dist is None:
             if reproducible:
                 raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our "
@@ -339,40 +337,40 @@ class PaddleFleetDriver(PaddleDriver):
             return dataloader
         # trainer
         elif dist == "dist":
+            args = self.get_dataloader_args(dataloader)
             # when the user's trainer.use_dist_sampler is True, whether checkpoint retraining is enabled does not
             # affect the behavior here;
-            if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator):
-                dataloader.batch_sampler.sampler.set_distributed(
+            if isinstance(args.sampler, ReproducibleIterator):
+                sampler = re_instantiate_sampler(args.sampler)
+                sampler.set_distributed(
                     num_replicas=self.world_size,
                     rank=self.global_rank,
                     pad=True
                 )
-                return dataloader
+                return replace_sampler(dataloader, sampler)
             else:
                 sampler = RandomSampler(
-                    dataset=dataloader.dataset,
-                    shuffle=shuffle,
-                    seed=int(os.environ.get("FASTNLP_SEED", 0))
+                    dataset=args.dataset,
+                    shuffle=args.shuffle,
+                    seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0))
                 )
                 sampler.set_distributed(
                     num_replicas=self.world_size,
                     rank=self.global_rank,
                     pad=True
                 )
-                dataloader.batch_sampler.sampler = sampler
-                return dataloader
+                return replace_sampler(dataloader, sampler)
         # evaluator
         elif dist == "unrepeatdist":
+            args = self.get_dataloader_args(dataloader)
             sampler = UnrepeatedDistributedSampler(
-                dataset=dataloader.dataset,
-                shuffle=shuffle,
-                seed=int(os.environ.get("FASTNLP_SEED", 0))
+                dataset=args.dataset,
+                shuffle=args.shuffle,
             )
             sampler.set_distributed(
                 num_replicas=self.world_size,
                 rank=self.global_rank
             )
-            dataloader.batch_sampler.sampler = sampler
-            return dataloader
+            return replace_sampler(dataloader, sampler)
         else:
             raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")
@@ -1,21 +1,31 @@
 import os
 import random
-from typing import Union, Optional, Callable, Dict
+from typing import Union, Optional, Dict
+from pathlib import Path
 from functools import partial
+from dataclasses import dataclass

 import numpy as np

-from .utils import _build_fp16_env
+from .utils import _build_fp16_env, optimizer_state_to_device
 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
 from fastNLP.core.drivers.driver import Driver
 from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device
 from fastNLP.envs import rank_zero_call
-from fastNLP.envs import FASTNLP_SEED_WORKERS
+from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
 from fastNLP.core.log import logger
+from fastNLP.core.samplers import ReproducibleBatchSampler

 if _NEED_IMPORT_PADDLE:
     import paddle
-    from paddle.io import DataLoader, IterableDataset
+    from paddle.io import (
+        DataLoader,
+        IterableDataset,
+        Dataset,
+        Sampler,
+        BatchSampler,
+        RandomSampler,
+    )
     from paddle.optimizer import Optimizer

 _reduces = {
@@ -69,6 +79,8 @@ class PaddleDriver(Driver):
             # TODO: for now we forbid the dataloader's dataset from being an IterableDataset;
             if isinstance(dataloader.dataset, IterableDataset):
                 raise TypeError("`IterableDataset` is not allowed.")
+            if dataloader.batch_sampler is None and dataloader.batch_size is None:
+                raise ValueError(f"At least one of `{dataloader_name}`'s `batch_sampler` and `batch_size` should be set.")
         else:
             if not isinstance(dataloader, Dict):
                 raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.")
@@ -79,6 +91,9 @@ class PaddleDriver(Driver):
                                          f"type, not {type(each_dataloader)}.")
                     if isinstance(each_dataloader.dataset, IterableDataset):
                         raise TypeError("`IterableDataset` is not allowed.")
+                    if each_dataloader.batch_sampler is None and each_dataloader.batch_size is None:
+                        raise ValueError(f"For each dataloader of parameter `{dataloader_name}`, at least one of "
+                                         f"`batch_sampler` and `batch_size` should be set.")
     @staticmethod
     def _check_optimizer_legality(optimizers):
@@ -153,45 +168,53 @@ class PaddleDriver(Driver):
         getattr(self.model, mode)()

     @rank_zero_call
-    def save_model(self, filepath: str, only_state_dict: bool = True, model_save_fn: Optional[Callable]=None, **kwargs):
+    def save_model(self, filepath: str, only_state_dict: bool = True, **kwargs):
         r"""
         The function that saves the model; note that it is the function `save` that handles checkpoint retraining;
         if `model_save_fn` is a callable, we run that function directly;

         :param filepath: the location of the file to save to (must include the file name);
-        :param only_state_dict: whether to save only the model's `state_dict`; note that this parameter only takes effect when `model_save_fn` is None;
-        :param model_save_fn: a user-supplied function that replaces this function's own saving logic; when it is not None, we call model_save_fn(path);
+        :param only_state_dict: whether to save only the model's `state_dict`;
+        :param kwargs:
+        :return:
         """
-        if model_save_fn is not None:
-            model_save_fn(filepath)
+        model = self.unwrap_model()
+        if only_state_dict:
+            states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
+            paddle.save(states, filepath)
         else:
-            model = self.unwrap_model()
-            if only_state_dict:
-                paddle.save(model.state_dict(), filepath)
+            # paddle requires extra arguments when saving a whole model
+            input_spec = kwargs.get("input_spec", None)
+            if input_spec is None:
+                raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.")
+            if self.model_device is not None:
+                if not self.is_distributed():
+                    self.move_model_to_device(model, "cpu")
+                paddle.jit.save(model, filepath, input_spec)
+                if not self.is_distributed():
+                    self.move_model_to_device(model, self.model_device)
             else:
-                input_spec = kwargs.get("input_spec", None)
-                if input_spec is None:
-                    raise Exception("To save the whole Paddle Layer, parameter 'input_spec' is needed.")
                 paddle.jit.save(model, filepath, input_spec)

-    @staticmethod
-    @rank_zero_call
-    def load_model(filepath: str, load_dict: bool = True):
+    def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs):
         r"""
         The function that loads the model; note that it is the function `load` that handles checkpoint retraining;

         :param filepath: the location of the object to load (must include the file name);
         :param load_dict: whether to load the state_dict, True by default. When the user set only_state_dict to False
             in save_model, i.e. saved the whole model, this parameter must be False as well
-        :return: the result of loading the specified file;
+        :param kwargs:
+        :return:
         """
-        if load_dict:
-            return paddle.load(filepath)
+        model = self.unwrap_model()
+        if only_state_dict:
+            model.load_dict(paddle.load(filepath))
         else:
-            return paddle.jit.load(filepath)
+            model.load_dict(paddle.jit.load(filepath).state_dict())
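Note: the two branches above correspond to the two on-disk formats paddle offers. A runnable sketch of both, outside of any driver (file names are made up):

    import paddle
    from paddle.static import InputSpec

    model = paddle.nn.Linear(10, 2)

    # only_state_dict=True path: a plain parameter dict.
    paddle.save(model.state_dict(), "model.pdparams")
    model.load_dict(paddle.load("model.pdparams"))

    # only_state_dict=False path: whole-model export; this is why save_model
    # raises when `input_spec` is missing from **kwargs.
    spec = [InputSpec(shape=[None, 10], dtype="float32")]
    paddle.jit.save(model, "model_whole", input_spec=spec)
    model.load_dict(paddle.jit.load("model_whole").state_dict())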
     @rank_zero_call
-    def save(self, folder, states: Dict):
+    def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
         r"""
         The save function for checkpoint retraining; it is responsible for saving the model's and the optimizers' state_dict;
         note that a driver should be stateless, i.e. no matter when a driver interface function is called, it should return the same result; checkpoint retraining therefore does not need to save the driver
@@ -203,48 +226,110 @@ class PaddleDriver(Driver):
         :param states: a dict passed in by the trainer that already contains the states of the other objects that must be saved
             for checkpoint retraining; the Driver only needs to save this object and does not need to understand it; at
             driver.load() time the states must be handed back, and the value returned by load() must match the one
             passed in here.
+        :param dataloader: the dataloader currently in use, whose state must be saved so that training can later resume
+            from the current iteration position.
+        :param only_state_dict: whether to save only the model's parameters; has no effect when should_save_model is False.
+        :param should_save_model: whether the model should be saved; if False, the Driver is not responsible for saving the model.
+        :return:
         """
-        # 1. save the model's state;
-        model = self.unwrap_model()
-        model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
-        # for a single-card driver, we actually should not (for now) account for the user running single-card mode
-        # inside a DDP environment at the cost of efficiency;
-        states["model_state_dict"] = model_state_dict
+        # the `dataloader` argument is the trainer's dataloader attribute: the driver never mutates any of its own
+        # dataloaders; instead we change trainer.dataloader to change the dataloader's state and adapt it to the
+        # training or evaluation environment;
+        # 1. the sampler's state, since we support resume training, i.e. recovering precisely to a specific batch;
+        # after initialization, a paddle DataLoader's batch_sampler may be None, or the batch_sampler set by the user
+        dataloader_args = self.get_dataloader_args(dataloader)
+        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
+            sampler = dataloader_args.batch_sampler
+        elif dataloader_args.sampler:
+            sampler = dataloader_args.sampler
+        else:
+            raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
+
+        if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
+            states['sampler_states'] = sampler.state_dict()
+        else:
+            raise RuntimeError(
+                'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')

-        # 2. save the optimizers' state;
+        # 2. save the model's state;
+        if should_save_model:
+            model = self.unwrap_model()
+            if only_state_dict:
+                model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
+                paddle.save(model_state_dict, folder.joinpath(FASTNLP_MODEL_FILENAME))
+                logger.debug("Save model state dict")
+            else:
+                input_spec = kwargs.get("input_spec", None)
+                if input_spec is None:
+                    raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.")
+                paddle.jit.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME), input_spec)
+                logger.debug("Save model")

+        # 3. save the optimizers' state;
         optimizers_state_dict = {}
         for i in range(len(self.optimizers)):
             optimizer: Optimizer = self.optimizers[i]
             optimizer_state = optimizer.state_dict()
-            optimizer_state = {name: param.cpu().detach().clone() for name, param in optimizer_state.items()}
+            optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu")
             optimizers_state_dict[f"optimizer{i}"] = optimizer_state  # note that deepcopy is not used here; tests show it is not needed;

-        states["optimizers_state_dict"] = optimizers_state_dict
-        paddle.save(states, folder)
-
-    def load(self, filepath) -> Dict:
-        r"""
-        The load function for checkpoint retraining; note that this function is responsible for reading the data and
-        restoring the model's and optimizers' state_dict, among other things;
-        a driver instance must first load the model's and optimizers' state_dict in this function, and then return a
-        states dict to the trainer; the inputs and outputs of the save and load functions must therefore correspond;
-        this function needs to run on all ranks.
+        logger.debug("Save optimizer state dict")
+        states["optimizers_state_dict"] = optimizers_state_dict
+        paddle.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))

-        :param filepath: the file name under which the checkpoint retraining states were saved;
-        :return: must return the `states` content that was passed into the save function;
-        """
-        states = paddle.load(filepath)
+    def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
+        states = paddle.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))

         # 1. load the optimizers' state;
         optimizers_state_dict = states["optimizers_state_dict"]
         for i in range(len(self.optimizers)):
-            optimizer: paddle.optimizer.Optimizer = self.optimizers[i]
+            optimizer: Optimizer = self.optimizers[i]
             optimizer.set_state_dict(optimizers_state_dict[f"optimizer{i}"])
+        logger.debug("Load optimizer state dict.")

         # 2. load the model's state;
-        model = self.unwrap_model()
-        model.load_dict(states["model_state_dict"])
+        if should_load_model:
+            model = self.unwrap_model()
+            if only_state_dict:
+                res = paddle.load(folder.joinpath(FASTNLP_MODEL_FILENAME))
+                model.load_dict(res)
+                logger.debug("Load model state dict.")
+            else:
+                model.load_dict(paddle.jit.load(folder.joinpath(FASTNLP_MODEL_FILENAME)).state_dict())
+                logger.debug("Load model.")
+
+        # 3. restore the sampler's state;
+        dataloader_args = self.get_dataloader_args(dataloader)
+        sampler = dataloader_args.sampler
+        if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
+            # in that case a ReproducibleBatchSampler has to be wrapped around it here
+            if self.is_distributed():
+                raise RuntimeError(
+                    "It is not allowed to use single device checkpoint retraining before but ddp now.")
+            sampler = ReproducibleBatchSampler(
+                batch_sampler=sampler,
+                batch_size=dataloader_args.batch_sampler.batch_size,
+                drop_last=dataloader_args.drop_last
+            )
+        sampler.load_state_dict(states['sampler_states'])
+        states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
+
+        # 4. adjust trainer_state.batch_idx_in_epoch
+        # the sampler is a RandomSampler-like sampler, not a batch_sampler;
+        if not isinstance(sampler, ReproducibleBatchSampler):
+            if dataloader_args.drop_last:
+                batch_idx_in_epoch = len(
+                    sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
+            else:
+                batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
+                    (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
+        # the sampler is a batch_sampler;
+        else:
+            batch_idx_in_epoch = sampler.batch_idx_in_epoch
+        states["batch_idx_in_epoch"] = batch_idx_in_epoch
+
+        self.barrier()
         return states

     def get_evaluate_context(self):
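Note: the `batch_idx_in_epoch` arithmetic in step 4 can be checked by hand. A small sketch with made-up numbers (10 samples, batch_size=4, 6 samples not yet consumed, i.e. one batch already served):

    length, batch_size, num_left = 10, 4, 6

    # drop_last=True: only full batches count on both sides of the subtraction.
    assert length // batch_size - num_left // batch_size == 1

    # drop_last=False: ceil-division on both sides, same result here.
    def ceil_div(n, d):
        return (n + d - 1) // d

    assert ceil_div(length, batch_size) - ceil_div(num_left, batch_size) == 1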
@@ -313,3 +398,53 @@ class PaddleDriver(Driver):
         """
         if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
             dataloader.batch_sampler.set_epoch(cur_epoch_idx)
+
+    @staticmethod
+    def get_dataloader_args(dataloader: "DataLoader"):
+        """
+        Gets the dataloader's `shuffle` and `drop_last` attributes, among others;
+        """
+        @dataclass
+        class Res:
+            dataset: Optional[Dataset] = None
+            batch_sampler: Optional[BatchSampler] = None
+            sampler: Optional[Sampler] = None
+            batch_size: Optional[int] = None
+            shuffle: Optional[bool] = None
+            drop_last: Optional[bool] = None
+
+        res = Res()
+
+        # a paddle DataLoader always has a dataset attribute;
+        res.dataset = dataloader.dataset
+
+        if dataloader.batch_sampler is not None:
+            res.batch_sampler = dataloader.batch_sampler
+            if hasattr(dataloader.batch_sampler, "batch_size"):
+                res.batch_size = getattr(dataloader.batch_sampler, "batch_size")
+            # the user passed their own batch_sampler and it has no "batch_size" attribute;
+            else:
+                dataloader_iter = iter(dataloader)
+                pre_sample = next(dataloader_iter)
+                res.batch_size = pre_sample.shape[0]
+
+            if hasattr(dataloader.batch_sampler, "sampler"):
+                res.sampler = dataloader.batch_sampler.sampler
+                if hasattr(dataloader.batch_sampler.sampler, "shuffle"):
+                    res.shuffle = dataloader.batch_sampler.sampler.shuffle
+                elif isinstance(dataloader.batch_sampler.sampler, RandomSampler):
+                    res.shuffle = True
+                else:
+                    res.shuffle = False
+            else:
+                res.sampler = None
+                res.shuffle = False
+
+            if hasattr(dataloader.batch_sampler, "drop_last"):
+                res.drop_last = getattr(dataloader.batch_sampler, "drop_last")
+            # the user passed their own batch_sampler and it has no "drop_last" attribute;
+            else:
+                res.drop_last = False
+
+        return res
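Note: a sketch of what `get_dataloader_args` reports for a stock paddle DataLoader (the fastNLP import path is inferred from this file's location; `ToyDataset` is made up):

    import paddle
    from paddle.io import DataLoader, Dataset
    from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver

    class ToyDataset(Dataset):
        def __getitem__(self, idx):
            return paddle.to_tensor([float(idx)])

        def __len__(self):
            return 8

    loader = DataLoader(ToyDataset(), batch_size=2, shuffle=True, drop_last=False)
    args = PaddleDriver.get_dataloader_args(loader)

    # shuffle=True makes paddle build a RandomSampler internally; paddle's
    # RandomSampler has no `shuffle` attribute, so it is the isinstance branch
    # above that reports shuffle=True here.
    assert args.batch_size == 2 and args.shuffle is True and args.drop_last is False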
@@ -2,6 +2,7 @@ import os
 from typing import Optional, Dict, Union

 from .paddle_driver import PaddleDriver
+from .utils import replace_batch_sampler, replace_sampler
 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
 from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES
 from fastNLP.core.utils import (
@@ -10,7 +11,7 @@ from fastNLP.core.utils import (
     get_paddle_device_id,
     paddle_move_data_to_device,
 )
-from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator, re_instantiate_sampler
 from fastNLP.core.log import logger

 if _NEED_IMPORT_PADDLE:
@@ -93,11 +94,8 @@ class PaddleSingleDriver(PaddleDriver):
             self._test_signature_fn = model.forward

     def setup(self):
-        user_visible_devices = os.environ[USER_CUDA_VISIBLE_DEVICES]
         device_id = get_paddle_device_id(self.model_device)
-        if user_visible_devices is not None and user_visible_devices != "":
-            # not empty, meaning the user set CUDA_VISIBLE_DEVICES
-            device_id = user_visible_devices.split(",")[device_id]
+        device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id]
         os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
         paddle.device.set_device("gpu:0")
         self.model.to("gpu:0")
@@ -145,26 +143,25 @@ class PaddleSingleDriver(PaddleDriver):
         assert dataloader.dataset_kind != _DatasetKind.ITER, \
             "FastNLP does not support `IteratorDataset` now."

         if isinstance(dist, ReproducibleBatchSampler):
-            dataloader.batch_sampler = dist
-            return dataloader
-
-        if isinstance(dist, ReproducibleIterator):
-            dataloader.batch_sampler.sampler = dist
-            return dataloader
+            return replace_batch_sampler(dataloader, dist)
+        elif isinstance(dist, ReproducibleIterator):
+            return replace_sampler(dataloader, dist)

         if reproducible:
-            if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator):
-                return dataloader
+            args = self.get_dataloader_args(dataloader)
+            if isinstance(args.sampler, ReproducibleIterator):
+                sampler = re_instantiate_sampler(args.sampler)
+                return replace_sampler(dataloader, sampler)
             elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler):
-                return dataloader
+                batch_sampler = re_instantiate_sampler(dataloader.batch_sampler)
+                return replace_batch_sampler(dataloader, batch_sampler)
             else:
-                # TODO
                 batch_sampler = ReproducibleBatchSampler(
-                    batch_sampler=dataloader.batch_sampler,
-                    batch_size=dataloader.batch_sampler.batch_size,
-                    drop_last=dataloader.drop_last
+                    batch_sampler=args.batch_sampler,
+                    batch_size=args.batch_size,
+                    drop_last=args.drop_last
                 )
-                dataloader.batch_sampler = batch_sampler
-                return dataloader
+                return replace_batch_sampler(dataloader, batch_sampler)
         else:
             return dataloader
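Note: the final `else` branch is the resume path for a plain dataloader: the vanilla batch sampler is wrapped so its position can be checkpointed. A hedged sketch of that wrapping on its own (import paths assumed; `ToyDataset` is made up):

    import paddle
    from paddle.io import DataLoader, Dataset
    from fastNLP.core.samplers import ReproducibleBatchSampler

    class ToyDataset(Dataset):
        def __getitem__(self, idx):
            return paddle.to_tensor([float(idx)])

        def __len__(self):
            return 8

    loader = DataLoader(ToyDataset(), batch_size=2)
    # Wrap the existing batch sampler; replace_batch_sampler then rebuilds the
    # dataloader around the wrapper, leaving `loader` itself untouched.
    wrapped = ReproducibleBatchSampler(
        batch_sampler=loader.batch_sampler,
        batch_size=loader.batch_sampler.batch_size,
        drop_last=False,
    )
    states = wrapped.state_dict()  # later restored via wrapped.load_state_dict(states)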
@@ -9,7 +9,7 @@ from enum import IntEnum
 from typing import Dict, Optional, Union

 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
-from fastNLP.core.utils import get_paddle_device_id, auto_param_call
+from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to
 from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES
 from fastNLP.core.log import logger

@@ -272,11 +272,9 @@ def get_device_from_visible(device: Union[str, int]):
     else:
         # use USER_CUDA_VISIBLE_DEVICES to get the device the user expects
         user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
-        if user_visible_devices is not None and user_visible_devices != "":
-            # not empty, meaning the user set CUDA_VISIBLE_DEVICES
-            idx = user_visible_devices.split(",")[idx]
-        else:
-            idx = str(idx)
+        if user_visible_devices is None:
+            raise RuntimeError("This situation cannot happen, please report a bug to us.")
+        idx = user_visible_devices.split(",")[idx]

     cuda_visible_devices_list = cuda_visible_devices.split(',')
     assert idx in cuda_visible_devices_list, "Can't find "\
@@ -285,31 +283,44 @@ def get_device_from_visible(device: Union[str, int]):
     res = cuda_visible_devices_list.index(idx)
     return res
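Note: a hand-worked example of the index mapping above, with made-up device lists. Suppose the process was launched with USER_CUDA_VISIBLE_DEVICES="2,5,7" while the launcher narrowed CUDA_VISIBLE_DEVICES to "5,7" for this worker:

    user_visible_devices = "2,5,7"   # what the user originally exposed
    cuda_visible_devices = "5,7"     # what this process actually sees

    idx = 1                                           # the user asks for "gpu:1"
    idx = user_visible_devices.split(",")[idx]        # -> "5", the physical card
    res = cuda_visible_devices.split(",").index(idx)  # -> 0, the local device id
    assert res == 0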
-def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"):
-    # get the instance attributes;
+def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler"):
+    """
+    Rebuilds a DataLoader from `batch_sampler`, so that the `batch_sampler` is replaced without affecting the original
+    `dataloader`. The case where the user customized the DataLoader is taken into account.
+    """
+    # get the instance attributes that do not start with an underscore;
     instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')}

-    # get the default signature of the dataloader's '__init__' function;
+    # get the default signature of the dataloader's '__init__' function; this yields the parameter names as well as
+    # their default values and types
     init_params = dict(inspect.signature(dataloader.__init__).parameters)

     # the reason this needs special handling is that a user may customize their own dataloader while, for convenience,
     # only defining a few parameters and passing the rest via **kwargs; in that case, if they added some other new
     # parameters when initializing their dataloader instance (this step is necessary in the first place, since it is
     # the only way we can inject the sampler; besides, the user may indeed have added new parameters via **kwargs),
-    # and wrote "super().__init__(**kwargs)", then we can only look these up in DataLoader;
+    # and wrote "super().__init__(**kwargs)", then we can only look these up in DataLoader; VAR_KEYWORD means **kwargs
     has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
     if has_variadic_kwargs:
         init_params.update(dict(inspect.signature(DataLoader.__init__).parameters))
         del init_params["self"]

     # since we may just have overwritten the user's customized dataloader parameters with DataLoader's defaults, this
     # has to be redone;
+    # collect the parameters that appear both as instance attributes and as parameter names with non-default values
     non_default_params = {name for name, p in init_params.items() if
                           name in instance_attrs and p.default != instance_attrs[name]}
     # add `dataset` as it might have been replaced with `*args`
     non_default_params.add("dataset")

+    # collect the non-default parameters and their values
     reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
-    reconstruct_args.update({"batch_sampler": sampler, "shuffle": False, "drop_last": False, "batch_size": 1})
+    # the member backing persistent_workers carries an underscore in the class, so add it here explicitly
+    reconstruct_args.update({
+        "batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1,
+        "persistent_workers": dataloader._persistent_workers,
+    })

+    # POSITIONAL_OR_KEYWORD means ordinary parameters
+    # collect the parameters that appear in the __init__ signature in ordinary form, have no default value, and are
+    # not in reconstruct_args, i.e. those that do not appear in both the __init__ signature and the instance members
     required_args = {
         p.name
         for p in init_params.values()
@@ -323,12 +334,9 @@ def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"):
         required_args = sorted(required_args)
         dataloader_self_name = dataloader.__class__.__name__
         raise Exception(
-            f"Trying to inject `DistributedBatchSampler` into the `{dataloader_self_name}` instance. "
+            f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. "
             "This would fail as some of the `__init__` arguments are not available as instance attributes. "
             f"The missing attributes are {required_args}. "
-            f"HINT: If you wrote the `{dataloader_self_name}` class, define `self.missing_arg_name` or "
-            "manually add the `DistributedBatchSampler` as: "
-            f"`{dataloader_self_name}(dataset, sampler=DistributedBatchSampler(dataset))`."
         )

     # this error targets the case where the passed-in dataloader is not a plain DataLoader but a customized one whose
     # __init__ has no **kwargs;
@@ -340,12 +348,33 @@ def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"):
         missing_kwargs = sorted(missing_kwargs)
         dataloader_self_name = dataloader.__class__.__name__
         raise Exception(
-            f"Trying to inject `DistributedBatchSampler` into the `{dataloader_self_name}` instance. "
+            f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. "
             "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
             f"The missing arguments are {missing_kwargs}. "
-            f"HINT: If you wrote the `{dataloader_self_name}` class, add the `__init__` arguments or "
-            "manually add the `DistributedBatchSampler` as: "
-            f"`{dataloader_self_name}(dataset, sampler=DistributedBatchSampler(dataset))`."
         )

     return type(dataloader)(**reconstruct_args)
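Note: a sketch of the reconstruction contract, assuming the helper is importable from this module; a subclass that mirrors its extra `__init__` argument as an instance attribute can be rebuilt, while one that hides it would trigger the second Exception above:

    import paddle
    from paddle.io import BatchSampler, DataLoader, Dataset
    from fastNLP.core.drivers.paddle_driver.utils import replace_batch_sampler

    class ToyDataset(Dataset):
        def __getitem__(self, idx):
            return paddle.to_tensor([float(idx)])

        def __len__(self):
            return 8

    class MyLoader(DataLoader):
        def __init__(self, dataset, my_option=0, **kwargs):
            self.my_option = my_option  # mirrored as an attribute -> reconstruction works
            super().__init__(dataset, **kwargs)

    loader = MyLoader(ToyDataset(), my_option=3, batch_size=2)
    new_bs = BatchSampler(dataset=loader.dataset, batch_size=4, drop_last=False)

    # A fresh MyLoader comes back: `my_option` is carried over via the
    # instance-attribute scan, and its batch_sampler becomes new_bs.
    new_loader = replace_batch_sampler(loader, new_bs)
    assert isinstance(new_loader, MyLoader) and new_loader.my_option == 3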
+def replace_sampler(dataloader, new_sampler):
+    """
+    Rebuilds a BatchSampler around `new_sampler` and swaps it into `dataloader`
+    """
+    new_batch_sampler = BatchSampler(
+        dataset=dataloader.batch_sampler.dataset,
+        sampler=new_sampler,
+        shuffle=isinstance(dataloader.batch_sampler.sampler, paddle.io.RandomSampler),
+        batch_size=dataloader.batch_sampler.batch_size,
+        drop_last=dataloader.batch_sampler.drop_last
+    )
+    return replace_batch_sampler(dataloader, new_batch_sampler)
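Note: `replace_sampler` composes with the function above: it swaps only the inner sampler, copying batch size and drop_last off the existing batch_sampler. A sketch of the intended call pattern, continuing the `MyLoader` example (fastNLP import paths assumed; not verified against every paddle version's BatchSampler argument checks):

    from fastNLP.core.samplers import RandomSampler
    from fastNLP.core.drivers.paddle_driver.utils import replace_sampler

    # Swap in a reproducible RandomSampler; the rebuild itself is delegated to
    # replace_batch_sampler, so `loader` is again left untouched.
    sampler = RandomSampler(dataset=loader.dataset, shuffle=True)
    new_loader = replace_sampler(loader, sampler)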
+def optimizer_state_to_device(state, device):
+    new_state = {}
+    for name, param in state.items():
+        if isinstance(param, dict):
+            new_state[name] = optimizer_state_to_device(param, device)
+        elif isinstance(param, paddle.Tensor):
+            new_state[name] = paddle_to(param, device).clone()
+        else:
+            new_state[name] = param
+    return new_state
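Note: a runnable sketch of `optimizer_state_to_device` (assumes paddle is installed and the helper is importable from this module):

    import paddle
    from fastNLP.core.drivers.paddle_driver.utils import optimizer_state_to_device

    model = paddle.nn.Linear(4, 2)
    optimizer = paddle.optimizer.Adam(learning_rate=1e-3, parameters=model.parameters())

    # One step so the optimizer actually owns moment tensors.
    loss = model(paddle.randn([3, 4])).mean()
    loss.backward()
    optimizer.step()

    # Recursively move every tensor in the state to CPU; non-tensor entries
    # (e.g. scalars in nested sub-dicts) are passed through unchanged.
    cpu_state = optimizer_state_to_device(optimizer.state_dict(), "cpu")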