@@ -10,7 +10,8 @@ from .utils import ( | |||||
_MODE_PARAMETER, | _MODE_PARAMETER, | ||||
get_device_from_visible, | get_device_from_visible, | ||||
reset_seed, | reset_seed, | ||||
replace_sampler | |||||
replace_sampler, | |||||
replace_batch_sampler, | |||||
) | ) | ||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
@@ -23,10 +24,12 @@ from fastNLP.core.utils import ( | |||||
from fastNLP.core.samplers import ( | from fastNLP.core.samplers import ( | ||||
RandomBatchSampler, | RandomBatchSampler, | ||||
ReproducibleSampler, | ReproducibleSampler, | ||||
ReproducibleIterator, | |||||
ReproducibleBatchSampler, | |||||
RandomSampler, | RandomSampler, | ||||
UnrepeatedDistributedSampler, | |||||
UnrepeatedSampler, | |||||
UnrepeatedSequentialSampler, | |||||
re_instantiate_sampler, | re_instantiate_sampler, | ||||
conversion_between_reproducible_and_unrepeated_sampler, | |||||
) | ) | ||||
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED | from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -261,7 +264,6 @@ class PaddleFleetDriver(PaddleDriver): | |||||
当用户使用了 `python -m paddle.distributed.launch xxx.py` 启动时,我们需要 | 当用户使用了 `python -m paddle.distributed.launch xxx.py` 启动时,我们需要 | ||||
根据 paddle 设置的环境变量来获得各种属性 | 根据 paddle 设置的环境变量来获得各种属性 | ||||
""" | """ | ||||
print("set_from_env") | |||||
self.world_size = dist.get_world_size() | self.world_size = dist.get_world_size() | ||||
self.global_rank = dist.get_rank() | self.global_rank = dist.get_rank() | ||||
@@ -325,23 +327,50 @@ class PaddleFleetDriver(PaddleDriver): | |||||
# 暂时不支持iterableDataset | # 暂时不支持iterableDataset | ||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | assert dataloader.dataset_kind != _DatasetKind.ITER, \ | ||||
"FastNLP does not support `IteratorDataset` now." | "FastNLP does not support `IteratorDataset` now." | ||||
if isinstance(dist, ReproducibleIterator): | |||||
dist = re_instantiate_sampler(dist) | |||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用; | |||||
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; | |||||
if isinstance(dist, ReproducibleBatchSampler): | |||||
dist.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_batch_sampler(dataloader, dist) | |||||
if isinstance(dist, ReproducibleSampler): | |||||
dist.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_sampler(dataloader, dist) | return replace_sampler(dataloader, dist) | ||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | |||||
# trainer, evaluator | # trainer, evaluator | ||||
# 自己初始化了分布式,什么都不做 | |||||
if dist is None: | if dist is None: | ||||
if reproducible: | if reproducible: | ||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our " | |||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " | |||||
"control.") | "control.") | ||||
else: | else: | ||||
if isinstance(dist, ReproducibleBatchSampler): | |||||
dist = re_instantiate_sampler(dist) | |||||
return replace_batch_sampler(dataloader, dist) | |||||
if isinstance(dist, ReproducibleSampler): | |||||
dist = re_instantiate_sampler(dist) | |||||
return replace_sampler(dataloader, dist) | |||||
return dataloader | return dataloader | ||||
# trainer | # trainer | ||||
elif dist == "dist": | elif dist == "dist": | ||||
args = self.get_dataloader_args(dataloader) | args = self.get_dataloader_args(dataloader) | ||||
# 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | ||||
if isinstance(args.sampler, ReproducibleIterator): | |||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||||
batch_sampler.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
elif isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = re_instantiate_sampler(args.sampler) | sampler = re_instantiate_sampler(args.sampler) | ||||
sampler.set_distributed( | sampler.set_distributed( | ||||
num_replicas=self.world_size, | num_replicas=self.world_size, | ||||
@@ -364,10 +393,14 @@ class PaddleFleetDriver(PaddleDriver): | |||||
# evaluator | # evaluator | ||||
elif dist == "unrepeatdist": | elif dist == "unrepeatdist": | ||||
args = self.get_dataloader_args(dataloader) | args = self.get_dataloader_args(dataloader) | ||||
sampler = UnrepeatedDistributedSampler( | |||||
dataset=args.dataset, | |||||
shuffle=args.shuffle, | |||||
) | |||||
if isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) | |||||
elif not isinstance(args.sampler, UnrepeatedSampler): | |||||
sampler = UnrepeatedSequentialSampler( | |||||
dataset=args.dataset | |||||
) | |||||
else: | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
sampler.set_distributed( | sampler.set_distributed( | ||||
num_replicas=self.world_size, | num_replicas=self.world_size, | ||||
rank=self.global_rank | rank=self.global_rank | ||||
@@ -14,7 +14,7 @@ from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device | |||||
from fastNLP.envs import rank_zero_call | from fastNLP.envs import rank_zero_call | ||||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
import paddle | import paddle | ||||
@@ -178,11 +178,13 @@ class PaddleDriver(Driver): | |||||
:param kwargs: | :param kwargs: | ||||
:return: | :return: | ||||
""" | """ | ||||
debug = kwargs.get("debug", False) | |||||
model = self.unwrap_model() | model = self.unwrap_model() | ||||
if only_state_dict: | if only_state_dict: | ||||
states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} | states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} | ||||
paddle.save(states, filepath) | paddle.save(states, filepath) | ||||
if debug: | |||||
logger.debug("Save model state dict.") | |||||
else: | else: | ||||
# paddle 在保存整个模型时需要传入额外参数 | # paddle 在保存整个模型时需要传入额外参数 | ||||
input_spec = kwargs.get("input_spec", None) | input_spec = kwargs.get("input_spec", None) | ||||
@@ -196,6 +198,8 @@ class PaddleDriver(Driver): | |||||
self.move_model_to_device(model, self.model_device) | self.move_model_to_device(model, self.model_device) | ||||
else: | else: | ||||
paddle.jit.save(model, filepath, input_spec) | paddle.jit.save(model, filepath, input_spec) | ||||
if debug: | |||||
logger.debug("Save model.") | |||||
def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs): | def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs): | ||||
r""" | r""" | ||||
@@ -207,11 +211,16 @@ class PaddleDriver(Driver): | |||||
:param kwargs: | :param kwargs: | ||||
:return: | :return: | ||||
""" | """ | ||||
debug = kwargs.get("debug", False) | |||||
model = self.unwrap_model() | model = self.unwrap_model() | ||||
if only_state_dict: | if only_state_dict: | ||||
model.load_dict(paddle.load(filepath)) | model.load_dict(paddle.load(filepath)) | ||||
if debug: | |||||
logger.debug("Load model state dict.") | |||||
else: | else: | ||||
model.load_dict(paddle.jit.load(filepath).state_dict()) | model.load_dict(paddle.jit.load(filepath).state_dict()) | ||||
if debug: | |||||
logger.debug("Load model.") | |||||
@rank_zero_call | @rank_zero_call | ||||
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | ||||
@@ -252,17 +261,7 @@ class PaddleDriver(Driver): | |||||
# 2. 保存模型的状态; | # 2. 保存模型的状态; | ||||
if should_save_model: | if should_save_model: | ||||
model = self.unwrap_model() | |||||
if only_state_dict: | |||||
model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} | |||||
paddle.save(model_state_dict, folder.joinpath(FASTNLP_MODEL_FILENAME)) | |||||
logger.debug("Save model state dict") | |||||
else: | |||||
input_spec = kwargs.get("input_spec", None) | |||||
if input_spec is None: | |||||
raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.") | |||||
paddle.jit.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME), input_spec) | |||||
logger.debug("Save model") | |||||
self.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, debug=True, **kwargs) | |||||
# 3. 保存 optimizers 的状态; | # 3. 保存 optimizers 的状态; | ||||
optimizers_state_dict = {} | optimizers_state_dict = {} | ||||
@@ -272,7 +271,7 @@ class PaddleDriver(Driver): | |||||
optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu") | optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu") | ||||
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | ||||
logger.debug("Save optimizer state dict") | |||||
logger.debug("Save optimizer state dict.") | |||||
states["optimizers_state_dict"] = optimizers_state_dict | states["optimizers_state_dict"] = optimizers_state_dict | ||||
paddle.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | paddle.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | ||||
@@ -289,30 +288,23 @@ class PaddleDriver(Driver): | |||||
# 2. 加载模型状态; | # 2. 加载模型状态; | ||||
if should_load_model: | if should_load_model: | ||||
model = self.unwrap_model() | |||||
if only_state_dict: | |||||
res = paddle.load(folder.joinpath(FASTNLP_MODEL_FILENAME)) | |||||
model.load_dict(res) | |||||
logger.debug("Load model state dict.") | |||||
else: | |||||
model.load_dict(paddle.jit.load(folder.joinpath(FASTNLP_MODEL_FILENAME)).state_dict()) | |||||
logger.debug("Load model.") | |||||
self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, debug=True) | |||||
# 3. 恢复 sampler 的状态; | # 3. 恢复 sampler 的状态; | ||||
dataloader_args = self.get_dataloader_args(dataloader) | dataloader_args = self.get_dataloader_args(dataloader) | ||||
sampler = dataloader_args.sampler | |||||
if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)): | |||||
# 说明这里需要使用 ReproduceSampler 来弄一下了 | |||||
if self.is_distributed(): | |||||
raise RuntimeError( | |||||
"It is not allowed to use single device checkpoint retraining before but ddp now.") | |||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||||
sampler = dataloader_args.batch_sampler | |||||
elif isinstance(dataloader_args.sampler, ReproducibleSampler): | |||||
sampler = dataloader_args.sampler | |||||
elif self.is_distributed(): | |||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") | |||||
else: | |||||
sampler = ReproducibleBatchSampler( | sampler = ReproducibleBatchSampler( | ||||
batch_sampler=sampler, | |||||
batch_size=dataloader_args.batch_sampler.batch_size, | |||||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | |||||
batch_size=dataloader_args.batch_size, | |||||
drop_last=dataloader_args.drop_last | drop_last=dataloader_args.drop_last | ||||
) | ) | ||||
sampler.load_state_dict(states['sampler_states']) | sampler.load_state_dict(states['sampler_states']) | ||||
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | ||||
# 4. 修改 trainer_state.batch_idx_in_epoch | # 4. 修改 trainer_state.batch_idx_in_epoch | ||||
@@ -420,6 +412,7 @@ class PaddleDriver(Driver): | |||||
res.dataset = dataloader.dataset | res.dataset = dataloader.dataset | ||||
if dataloader.batch_sampler is not None: | if dataloader.batch_sampler is not None: | ||||
# 不过在 paddle 中,我们限定了 batch_sampler 不能为 None | |||||
res.batch_sampler = dataloader.batch_sampler | res.batch_sampler = dataloader.batch_sampler | ||||
if hasattr(dataloader.batch_sampler, "batch_size"): | if hasattr(dataloader.batch_sampler, "batch_size"): | ||||
res.batch_size = getattr(dataloader.batch_sampler, "batch_size") | res.batch_size = getattr(dataloader.batch_sampler, "batch_size") | ||||
@@ -11,7 +11,7 @@ from fastNLP.core.utils import ( | |||||
get_paddle_device_id, | get_paddle_device_id, | ||||
paddle_move_data_to_device, | paddle_move_data_to_device, | ||||
) | ) | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
@@ -137,55 +137,34 @@ class PaddleSingleDriver(PaddleDriver): | |||||
""" | """ | ||||
return paddle_move_data_to_device(batch, "gpu:0") | return paddle_move_data_to_device(batch, "gpu:0") | ||||
<<<<<<< HEAD | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator], | |||||
======= | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], | |||||
>>>>>>> 388e426d78e8985a2f34dc83dfffe881274239a1 | |||||
reproducible: bool = False, sampler_or_batch_sampler=None): | |||||
# 暂时不支持IteratorDataset | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, | |||||
reproducible: bool = False): | |||||
# 暂时不支持iterableDataset | |||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | assert dataloader.dataset_kind != _DatasetKind.ITER, \ | ||||
"FastNLP does not support `IteratorDataset` now." | |||||
"FastNLP does not support `IteratorDataset` now." | |||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||||
if isinstance(dist, ReproducibleBatchSampler): | if isinstance(dist, ReproducibleBatchSampler): | ||||
<<<<<<< HEAD | |||||
return replace_batch_sampler(dataloader, dist) | return replace_batch_sampler(dataloader, dist) | ||||
elif isinstance(dist, ReproducibleIterator): | |||||
return replace_sampler(dataloader, dist) | |||||
if reproducible: | |||||
args = self.get_dataloader_args(dataloader) | |||||
if isinstance(args.sampler, ReproducibleIterator): | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
return replace_sampler(dataloader, sampler) | |||||
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): | |||||
batch_sampler = re_instantiate_sampler(dataloader.batch_sampler) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
else: | |||||
batch_sampler = ReproducibleBatchSampler( | |||||
batch_sampler=args.batch_sampler, | |||||
batch_size=args.batch_size, | |||||
drop_last=args.drop_last | |||||
======= | |||||
dataloader.batch_sampler = dist | |||||
return dataloader | |||||
if isinstance(dist, ReproducibleSampler): | |||||
dataloader.batch_sampler.sampler = dist | |||||
return dataloader | |||||
elif isinstance(dist, ReproducibleSampler): | |||||
return replace_sampler(dataloader, dist) | |||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | |||||
args = self.get_dataloader_args(dataloader) | |||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
elif isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
return replace_sampler(dataloader, sampler) | |||||
if reproducible: | if reproducible: | ||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||||
return dataloader | |||||
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): | |||||
return dataloader | |||||
else: | |||||
# TODO | |||||
batch_sampler = ReproducibleBatchSampler( | |||||
batch_sampler=dataloader.batch_sampler, | |||||
batch_size=dataloader.batch_sampler.batch_size, | |||||
drop_last=dataloader.drop_last | |||||
>>>>>>> 388e426d78e8985a2f34dc83dfffe881274239a1 | |||||
) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
batch_sampler = ReproducibleBatchSampler( | |||||
batch_sampler=args.batch_sampler, | |||||
batch_size=args.batch_size, | |||||
drop_last=args.drop_last | |||||
) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
else: | else: | ||||
return dataloader | return dataloader | ||||