Browse Source

跟进断点重训

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
2366bc320b
3 changed files with 94 additions and 89 deletions
  1. +46
    -13
      fastNLP/core/drivers/paddle_driver/fleet.py
  2. +24
    -31
      fastNLP/core/drivers/paddle_driver/paddle_driver.py
  3. +24
    -45
      fastNLP/core/drivers/paddle_driver/single_device.py

+ 46
- 13
fastNLP/core/drivers/paddle_driver/fleet.py View File

@@ -10,7 +10,8 @@ from .utils import (
_MODE_PARAMETER, _MODE_PARAMETER,
get_device_from_visible, get_device_from_visible,
reset_seed, reset_seed,
replace_sampler
replace_sampler,
replace_batch_sampler,
) )


from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
@@ -23,10 +24,12 @@ from fastNLP.core.utils import (
from fastNLP.core.samplers import ( from fastNLP.core.samplers import (
RandomBatchSampler, RandomBatchSampler,
ReproducibleSampler, ReproducibleSampler,
ReproducibleIterator,
ReproducibleBatchSampler,
RandomSampler, RandomSampler,
UnrepeatedDistributedSampler,
UnrepeatedSampler,
UnrepeatedSequentialSampler,
re_instantiate_sampler, re_instantiate_sampler,
conversion_between_reproducible_and_unrepeated_sampler,
) )
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED
from fastNLP.core.log import logger from fastNLP.core.log import logger
@@ -261,7 +264,6 @@ class PaddleFleetDriver(PaddleDriver):
当用户使用了 `python -m paddle.distributed.launch xxx.py` 启动时,我们需要 当用户使用了 `python -m paddle.distributed.launch xxx.py` 启动时,我们需要
根据 paddle 设置的环境变量来获得各种属性 根据 paddle 设置的环境变量来获得各种属性
""" """
print("set_from_env")
self.world_size = dist.get_world_size() self.world_size = dist.get_world_size()
self.global_rank = dist.get_rank() self.global_rank = dist.get_rank()


@@ -325,23 +327,50 @@ class PaddleFleetDriver(PaddleDriver):
# 暂时不支持iterableDataset # 暂时不支持iterableDataset
assert dataloader.dataset_kind != _DatasetKind.ITER, \ assert dataloader.dataset_kind != _DatasetKind.ITER, \
"FastNLP does not support `IteratorDataset` now." "FastNLP does not support `IteratorDataset` now."
if isinstance(dist, ReproducibleIterator):
dist = re_instantiate_sampler(dist)
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用;
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数;
if isinstance(dist, ReproducibleBatchSampler):
dist.set_distributed(
num_replicas=self.world_size,
rank=self.global_rank,
pad=True
)
return replace_batch_sampler(dataloader, dist)
if isinstance(dist, ReproducibleSampler):
dist.set_distributed(
num_replicas=self.world_size,
rank=self.global_rank,
pad=True
)
return replace_sampler(dataloader, dist) return replace_sampler(dataloader, dist)


# 如果 dist 为 str 或者 None,说明是在 trainer 初始化时调用;
# trainer, evaluator # trainer, evaluator
# 自己初始化了分布式,什么都不做
if dist is None: if dist is None:
if reproducible: if reproducible:
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our "
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
"control.") "control.")
else: else:
if isinstance(dist, ReproducibleBatchSampler):
dist = re_instantiate_sampler(dist)
return replace_batch_sampler(dataloader, dist)
if isinstance(dist, ReproducibleSampler):
dist = re_instantiate_sampler(dist)
return replace_sampler(dataloader, dist)
return dataloader return dataloader
# trainer # trainer
elif dist == "dist": elif dist == "dist":
args = self.get_dataloader_args(dataloader) args = self.get_dataloader_args(dataloader)
# 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为;
if isinstance(args.sampler, ReproducibleIterator):
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
batch_sampler = re_instantiate_sampler(args.batch_sampler)
batch_sampler.set_distributed(
num_replicas=self.world_size,
rank=self.global_rank,
pad=True
)
return replace_batch_sampler(dataloader, batch_sampler)
elif isinstance(args.sampler, ReproducibleSampler):
sampler = re_instantiate_sampler(args.sampler) sampler = re_instantiate_sampler(args.sampler)
sampler.set_distributed( sampler.set_distributed(
num_replicas=self.world_size, num_replicas=self.world_size,
@@ -364,10 +393,14 @@ class PaddleFleetDriver(PaddleDriver):
# evaluator # evaluator
elif dist == "unrepeatdist": elif dist == "unrepeatdist":
args = self.get_dataloader_args(dataloader) args = self.get_dataloader_args(dataloader)
sampler = UnrepeatedDistributedSampler(
dataset=args.dataset,
shuffle=args.shuffle,
)
if isinstance(args.sampler, ReproducibleSampler):
sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
elif not isinstance(args.sampler, UnrepeatedSampler):
sampler = UnrepeatedSequentialSampler(
dataset=args.dataset
)
else:
sampler = re_instantiate_sampler(args.sampler)
sampler.set_distributed( sampler.set_distributed(
num_replicas=self.world_size, num_replicas=self.world_size,
rank=self.global_rank rank=self.global_rank


+ 24
- 31
fastNLP/core/drivers/paddle_driver/paddle_driver.py View File

@@ -14,7 +14,7 @@ from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device
from fastNLP.envs import rank_zero_call from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler


if _NEED_IMPORT_PADDLE: if _NEED_IMPORT_PADDLE:
import paddle import paddle
@@ -178,11 +178,13 @@ class PaddleDriver(Driver):
:param kwargs: :param kwargs:
:return: :return:
""" """
debug = kwargs.get("debug", False)
model = self.unwrap_model() model = self.unwrap_model()

if only_state_dict: if only_state_dict:
states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
paddle.save(states, filepath) paddle.save(states, filepath)
if debug:
logger.debug("Save model state dict.")
else: else:
# paddle 在保存整个模型时需要传入额外参数 # paddle 在保存整个模型时需要传入额外参数
input_spec = kwargs.get("input_spec", None) input_spec = kwargs.get("input_spec", None)
@@ -196,6 +198,8 @@ class PaddleDriver(Driver):
self.move_model_to_device(model, self.model_device) self.move_model_to_device(model, self.model_device)
else: else:
paddle.jit.save(model, filepath, input_spec) paddle.jit.save(model, filepath, input_spec)
if debug:
logger.debug("Save model.")


def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs): def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs):
r""" r"""
@@ -207,11 +211,16 @@ class PaddleDriver(Driver):
:param kwargs: :param kwargs:
:return: :return:
""" """
debug = kwargs.get("debug", False)
model = self.unwrap_model() model = self.unwrap_model()
if only_state_dict: if only_state_dict:
model.load_dict(paddle.load(filepath)) model.load_dict(paddle.load(filepath))
if debug:
logger.debug("Load model state dict.")
else: else:
model.load_dict(paddle.jit.load(filepath).state_dict()) model.load_dict(paddle.jit.load(filepath).state_dict())
if debug:
logger.debug("Load model.")


@rank_zero_call @rank_zero_call
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
@@ -252,17 +261,7 @@ class PaddleDriver(Driver):


# 2. 保存模型的状态; # 2. 保存模型的状态;
if should_save_model: if should_save_model:
model = self.unwrap_model()
if only_state_dict:
model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
paddle.save(model_state_dict, folder.joinpath(FASTNLP_MODEL_FILENAME))
logger.debug("Save model state dict")
else:
input_spec = kwargs.get("input_spec", None)
if input_spec is None:
raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.")
paddle.jit.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME), input_spec)
logger.debug("Save model")
self.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, debug=True, **kwargs)


# 3. 保存 optimizers 的状态; # 3. 保存 optimizers 的状态;
optimizers_state_dict = {} optimizers_state_dict = {}
@@ -272,7 +271,7 @@ class PaddleDriver(Driver):
optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu") optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu")
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的;


logger.debug("Save optimizer state dict")
logger.debug("Save optimizer state dict.")
states["optimizers_state_dict"] = optimizers_state_dict states["optimizers_state_dict"] = optimizers_state_dict
paddle.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) paddle.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))


@@ -289,30 +288,23 @@ class PaddleDriver(Driver):


# 2. 加载模型状态; # 2. 加载模型状态;
if should_load_model: if should_load_model:
model = self.unwrap_model()
if only_state_dict:
res = paddle.load(folder.joinpath(FASTNLP_MODEL_FILENAME))
model.load_dict(res)
logger.debug("Load model state dict.")
else:
model.load_dict(paddle.jit.load(folder.joinpath(FASTNLP_MODEL_FILENAME)).state_dict())
logger.debug("Load model.")
self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, debug=True)


# 3. 恢复 sampler 的状态; # 3. 恢复 sampler 的状态;
dataloader_args = self.get_dataloader_args(dataloader) dataloader_args = self.get_dataloader_args(dataloader)
sampler = dataloader_args.sampler
if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
# 说明这里需要使用 ReproduceSampler 来弄一下了
if self.is_distributed():
raise RuntimeError(
"It is not allowed to use single device checkpoint retraining before but ddp now.")
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
sampler = dataloader_args.batch_sampler
elif isinstance(dataloader_args.sampler, ReproducibleSampler):
sampler = dataloader_args.sampler
elif self.is_distributed():
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.")
else:
sampler = ReproducibleBatchSampler( sampler = ReproducibleBatchSampler(
batch_sampler=sampler,
batch_size=dataloader_args.batch_sampler.batch_size,
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last drop_last=dataloader_args.drop_last
) )
sampler.load_state_dict(states['sampler_states']) sampler.load_state_dict(states['sampler_states'])

states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)


# 4. 修改 trainer_state.batch_idx_in_epoch # 4. 修改 trainer_state.batch_idx_in_epoch
@@ -420,6 +412,7 @@ class PaddleDriver(Driver):
res.dataset = dataloader.dataset res.dataset = dataloader.dataset


if dataloader.batch_sampler is not None: if dataloader.batch_sampler is not None:
# 不过在 paddle 中,我们限定了 batch_sampler 不能为 None
res.batch_sampler = dataloader.batch_sampler res.batch_sampler = dataloader.batch_sampler
if hasattr(dataloader.batch_sampler, "batch_size"): if hasattr(dataloader.batch_sampler, "batch_size"):
res.batch_size = getattr(dataloader.batch_sampler, "batch_size") res.batch_size = getattr(dataloader.batch_sampler, "batch_size")


+ 24
- 45
fastNLP/core/drivers/paddle_driver/single_device.py View File

@@ -11,7 +11,7 @@ from fastNLP.core.utils import (
get_paddle_device_id, get_paddle_device_id,
paddle_move_data_to_device, paddle_move_data_to_device,
) )
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler
from fastNLP.core.log import logger from fastNLP.core.log import logger


if _NEED_IMPORT_PADDLE: if _NEED_IMPORT_PADDLE:
@@ -137,55 +137,34 @@ class PaddleSingleDriver(PaddleDriver):
""" """
return paddle_move_data_to_device(batch, "gpu:0") return paddle_move_data_to_device(batch, "gpu:0")


<<<<<<< HEAD
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
=======
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler],
>>>>>>> 388e426d78e8985a2f34dc83dfffe881274239a1
reproducible: bool = False, sampler_or_batch_sampler=None):
# 暂时不支持IteratorDataset
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None,
reproducible: bool = False):

# 暂时不支持iterableDataset
assert dataloader.dataset_kind != _DatasetKind.ITER, \ assert dataloader.dataset_kind != _DatasetKind.ITER, \
"FastNLP does not support `IteratorDataset` now."
"FastNLP does not support `IteratorDataset` now."
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用;
if isinstance(dist, ReproducibleBatchSampler): if isinstance(dist, ReproducibleBatchSampler):
<<<<<<< HEAD
return replace_batch_sampler(dataloader, dist) return replace_batch_sampler(dataloader, dist)
elif isinstance(dist, ReproducibleIterator):
return replace_sampler(dataloader, dist)

if reproducible:
args = self.get_dataloader_args(dataloader)
if isinstance(args.sampler, ReproducibleIterator):
sampler = re_instantiate_sampler(args.sampler)
return replace_sampler(dataloader, sampler)
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler):
batch_sampler = re_instantiate_sampler(dataloader.batch_sampler)
return replace_batch_sampler(dataloader, batch_sampler)
else:
batch_sampler = ReproducibleBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last
=======
dataloader.batch_sampler = dist
return dataloader
if isinstance(dist, ReproducibleSampler):
dataloader.batch_sampler.sampler = dist
return dataloader
elif isinstance(dist, ReproducibleSampler):
return replace_sampler(dataloader, dist)

# 如果 dist 为 str 或者 None,说明是在 trainer 初始化时调用;
args = self.get_dataloader_args(dataloader)
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
batch_sampler = re_instantiate_sampler(args.batch_sampler)
return replace_batch_sampler(dataloader, batch_sampler)
elif isinstance(args.sampler, ReproducibleSampler):
sampler = re_instantiate_sampler(args.sampler)
return replace_sampler(dataloader, sampler)


if reproducible: if reproducible:
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler):
return dataloader
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler):
return dataloader
else:
# TODO
batch_sampler = ReproducibleBatchSampler(
batch_sampler=dataloader.batch_sampler,
batch_size=dataloader.batch_sampler.batch_size,
drop_last=dataloader.drop_last
>>>>>>> 388e426d78e8985a2f34dc83dfffe881274239a1
)
return replace_batch_sampler(dataloader, batch_sampler)
batch_sampler = ReproducibleBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last
)
return replace_batch_sampler(dataloader, batch_sampler)
else: else:
return dataloader return dataloader




Loading…
Cancel
Save