@@ -10,7 +10,8 @@ from .utils import (
     _MODE_PARAMETER,
     get_device_from_visible,
     reset_seed,
-    replace_sampler
+    replace_sampler,
+    replace_batch_sampler,
 )
 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
@@ -23,10 +24,12 @@ from fastNLP.core.utils import (
 from fastNLP.core.samplers import (
     RandomBatchSampler,
     ReproducibleSampler,
-    ReproducibleIterator,
+    ReproducibleBatchSampler,
     RandomSampler,
     UnrepeatedDistributedSampler,
+    UnrepeatedSampler,
+    UnrepeatedSequentialSampler,
     re_instantiate_sampler,
+    conversion_between_reproducible_and_unrepeated_sampler,
 )
 from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED
 from fastNLP.core.log import logger
@@ -261,7 +264,6 @@ class PaddleFleetDriver(PaddleDriver):
         When the user launches with `python -m paddle.distributed.launch xxx.py`, we need to
         read the various attributes back from the environment variables that paddle has set.
         """
-        print("set_from_env")
         self.world_size = dist.get_world_size()
         self.global_rank = dist.get_rank()
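The only change in this hunk is dropping a leftover debugging `print`; the surviving lines show the pattern: under `python -m paddle.distributed.launch`, paddle itself populates the process context, and the driver only reads it back. A minimal standalone sketch of that pattern (not fastNLP code; the environment-variable names are the launcher's own):

```python
import paddle.distributed as dist

def read_fleet_context():
    # After `python -m paddle.distributed.launch xxx.py`, the launcher exports
    # variables such as PADDLE_TRAINERS_NUM / PADDLE_TRAINER_ID, and these
    # calls simply reflect them back.
    world_size = dist.get_world_size()
    global_rank = dist.get_rank()
    return world_size, global_rank
```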
@@ -325,23 +327,50 @@ class PaddleFleetDriver(PaddleDriver):
         # IterableDataset is not supported for now
         assert dataloader.dataset_kind != _DatasetKind.ITER, \
             "FastNLP does not support `IteratorDataset` now."
-        if isinstance(dist, ReproducibleIterator):
-            dist = re_instantiate_sampler(dist)
+        # if dist is a ReproducibleBatchSampler or ReproducibleSampler, this call comes from driver.load during checkpoint retraining;
+        # note that the user does not need to call dist_sampler.set_distributed; with TorchDDPDriver it has already been called during Trainer initialization;
+        if isinstance(dist, ReproducibleBatchSampler):
+            dist.set_distributed(
+                num_replicas=self.world_size,
+                rank=self.global_rank,
+                pad=True
+            )
+            return replace_batch_sampler(dataloader, dist)
+        if isinstance(dist, ReproducibleSampler):
+            dist.set_distributed(
+                num_replicas=self.world_size,
+                rank=self.global_rank,
+                pad=True
+            )
+            return replace_sampler(dataloader, dist)
         # if dist is a str or None, this call comes from trainer initialization;
         # trainer, evaluator
         # the user initialized the distributed setup themselves, so do nothing
         if dist is None:
             if reproducible:
-                raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our "
+                raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
                                    "control.")
             else:
+                if isinstance(dist, ReproducibleBatchSampler):
+                    dist = re_instantiate_sampler(dist)
+                    return replace_batch_sampler(dataloader, dist)
+                if isinstance(dist, ReproducibleSampler):
+                    dist = re_instantiate_sampler(dist)
+                    return replace_sampler(dataloader, dist)
                 return dataloader
         # trainer
         elif dist == "dist":
             args = self.get_dataloader_args(dataloader)
             # if the user's trainer.use_dist_sampler is True, whether or not checkpoint retraining is in progress does not change the behaviour here;
-            if isinstance(args.sampler, ReproducibleIterator):
+            if isinstance(args.batch_sampler, ReproducibleBatchSampler):
+                batch_sampler = re_instantiate_sampler(args.batch_sampler)
+                batch_sampler.set_distributed(
+                    num_replicas=self.world_size,
+                    rank=self.global_rank,
+                    pad=True
+                )
+                return replace_batch_sampler(dataloader, batch_sampler)
+            elif isinstance(args.sampler, ReproducibleSampler):
                 sampler = re_instantiate_sampler(args.sampler)
                 sampler.set_distributed(
                     num_replicas=self.world_size,
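The branches added above (and the `elif dist == "dist"` path that continues into the next hunk) share one contract: a reproducible sampler is first told about the topology via `set_distributed`, then swapped into the dataloader with `replace_batch_sampler`/`replace_sampler`. A hedged sketch of that contract outside the driver, using fastNLP's `RandomSampler`; the dataset and sizes are made up:

```python
from fastNLP.core.samplers import RandomSampler

my_dataset = list(range(100))  # stand-in for any map-style dataset
sampler = RandomSampler(my_dataset, shuffle=True)
# pad=True repeats a few samples so every rank sees the same number of batches,
# which is what training needs; evaluation uses the "unrepeatdist" path below.
sampler.set_distributed(num_replicas=2, rank=0, pad=True)
```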
@@ -364,10 +393,14 @@ class PaddleFleetDriver(PaddleDriver):
         # evaluator
         elif dist == "unrepeatdist":
             args = self.get_dataloader_args(dataloader)
-            sampler = UnrepeatedDistributedSampler(
-                dataset=args.dataset,
-                shuffle=args.shuffle,
-            )
+            if isinstance(args.sampler, ReproducibleSampler):
+                sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
+            elif not isinstance(args.sampler, UnrepeatedSampler):
+                sampler = UnrepeatedSequentialSampler(
+                    dataset=args.dataset
+                )
+            else:
+                sampler = re_instantiate_sampler(args.sampler)
             sampler.set_distributed(
                 num_replicas=self.world_size,
                 rank=self.global_rank
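For evaluation (`dist == "unrepeatdist"`), the rewrite stops unconditionally building an `UnrepeatedDistributedSampler`: an existing reproducible sampler is converted with `conversion_between_reproducible_and_unrepeated_sampler`, anything that is not already an `UnrepeatedSampler` falls back to `UnrepeatedSequentialSampler`, and an existing unrepeated sampler is merely re-instantiated. "Unrepeated" means no index is duplicated across ranks, so every sample is evaluated exactly once. A toy illustration of such a split; striding is just one way to do it, and fastNLP's samplers may partition differently:

```python
def unrepeated_shard(indices, num_replicas, rank):
    # Every index lands on exactly one rank; ranks may get unequal counts,
    # which is fine for metrics but not for lock-step training.
    return indices[rank::num_replicas]

print(unrepeated_shard(list(range(10)), 3, 0))  # [0, 3, 6, 9]
print(unrepeated_shard(list(range(10)), 3, 1))  # [1, 4, 7]
print(unrepeated_shard(list(range(10)), 3, 2))  # [2, 5, 8]
```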
@@ -14,7 +14,7 @@ from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device
 from fastNLP.envs import rank_zero_call
 from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
 from fastNLP.core.log import logger
-from fastNLP.core.samplers import ReproducibleBatchSampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler

 if _NEED_IMPORT_PADDLE:
     import paddle
@@ -178,11 +178,13 @@ class PaddleDriver(Driver):
         :param kwargs:
         :return:
         """
+        debug = kwargs.get("debug", False)
         model = self.unwrap_model()
         if only_state_dict:
             states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
             paddle.save(states, filepath)
+            if debug:
+                logger.debug("Save model state dict.")
         else:
             # paddle needs extra arguments when saving the whole model
             input_spec = kwargs.get("input_spec", None)
@@ -196,6 +198,8 @@
             self.move_model_to_device(model, self.model_device)
         else:
             paddle.jit.save(model, filepath, input_spec)
+            if debug:
+                logger.debug("Save model.")

     def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs):
         r"""
@@ -207,11 +211,16 @@
         :param kwargs:
         :return:
         """
+        debug = kwargs.get("debug", False)
         model = self.unwrap_model()
         if only_state_dict:
             model.load_dict(paddle.load(filepath))
+            if debug:
+                logger.debug("Load model state dict.")
         else:
             model.load_dict(paddle.jit.load(filepath).state_dict())
+            if debug:
+                logger.debug("Load model.")

     @rank_zero_call
     def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
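The `PaddleDriver.save_model`/`load_model` hunks above thread an optional `debug` keyword through `kwargs`, which is what lets the checkpointing code below delegate to them instead of duplicating their bodies. Saving a whole paddle `Layer` goes through `paddle.jit.save` and therefore needs an `input_spec`; a hedged usage sketch, where `driver` and the shapes are illustrative:

```python
from paddle.static import InputSpec

# State-dict path: a plain `paddle.save` of cloned CPU tensors.
driver.save_model("ckpt/model", only_state_dict=True)

# Whole-model path: a traced save, so the input signature must be declared.
driver.save_model(
    "ckpt/model_jit",
    only_state_dict=False,
    input_spec=[InputSpec(shape=[None, 32], dtype="float32")],
)
```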
@@ -252,17 +261,7 @@
         # 2. save the model's state;
         if should_save_model:
             model = self.unwrap_model()
-            if only_state_dict:
-                model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
-                paddle.save(model_state_dict, folder.joinpath(FASTNLP_MODEL_FILENAME))
-                logger.debug("Save model state dict")
-            else:
-                input_spec = kwargs.get("input_spec", None)
-                if input_spec is None:
-                    raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.")
-                paddle.jit.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME), input_spec)
-                logger.debug("Save model")
+            self.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, debug=True, **kwargs)

         # 3. save the optimizers' state;
         optimizers_state_dict = {}
@@ -272,7 +271,7 @@
             optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu")
             optimizers_state_dict[f"optimizer{i}"] = optimizer_state  # note: no deepcopy here; testing showed it is unnecessary
-        logger.debug("Save optimizer state dict")
+        logger.debug("Save optimizer state dict.")
         states["optimizers_state_dict"] = optimizers_state_dict
         paddle.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
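For orientation, the checkpoint written here is a plain dict; the two keys visible in this diff are sketched below. The values are toy stand-ins and the path is illustrative (the real name comes from `FASTNLP_CHECKPOINT_FILENAME`):

```python
import paddle

states = {
    "optimizers_state_dict": {"optimizer0": {"state": {}}},   # one entry per optimizer
    "sampler_states": {"num_consumed_samples": 0},            # restored via load_state_dict
}
paddle.save(states, "checkpoints/fastnlp_checkpoint.pkl.tar")
```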
@@ -289,30 +288,23 @@
         # 2. load the model state;
         if should_load_model:
             model = self.unwrap_model()
-            if only_state_dict:
-                res = paddle.load(folder.joinpath(FASTNLP_MODEL_FILENAME))
-                model.load_dict(res)
-                logger.debug("Load model state dict.")
-            else:
-                model.load_dict(paddle.jit.load(folder.joinpath(FASTNLP_MODEL_FILENAME)).state_dict())
-                logger.debug("Load model.")
+            self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, debug=True)

         # 3. restore the sampler's state;
         dataloader_args = self.get_dataloader_args(dataloader)
-        sampler = dataloader_args.sampler
-        if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
-            # this means a reproducible sampler has to be substituted in here
-            if self.is_distributed():
-                raise RuntimeError(
-                    "It is not allowed to use single device checkpoint retraining before but ddp now.")
+        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
+            sampler = dataloader_args.batch_sampler
+        elif isinstance(dataloader_args.sampler, ReproducibleSampler):
+            sampler = dataloader_args.sampler
+        elif self.is_distributed():
+            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our `ReproducibleBatchSampler` or `ReproducibleSampler`.")
+        else:
             sampler = ReproducibleBatchSampler(
-                batch_sampler=sampler,
-                batch_size=dataloader_args.batch_sampler.batch_size,
+                batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
+                batch_size=dataloader_args.batch_size,
                 drop_last=dataloader_args.drop_last
             )
         sampler.load_state_dict(states['sampler_states'])
         states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)

         # 4. update trainer_state.batch_idx_in_epoch
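A condensed sketch of the step-3 fallback above: when neither a `ReproducibleBatchSampler` nor a `ReproducibleSampler` is found on the dataloader (and we are not distributed), the plain batch sampler is wrapped so it gains `load_state_dict` and can resume mid-epoch. Here `dataloader_args` stands for the object returned by `get_dataloader_args`:

```python
from fastNLP.core.samplers import ReproducibleBatchSampler

sampler = ReproducibleBatchSampler(
    # fall back to the raw sampler when no batch_sampler is exposed
    batch_sampler=dataloader_args.batch_sampler,
    batch_size=dataloader_args.batch_size,
    drop_last=dataloader_args.drop_last,
)
sampler.load_state_dict(states["sampler_states"])  # rewind to the saved position
```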
@@ -420,6 +412,7 @@
         res.dataset = dataloader.dataset
         if dataloader.batch_sampler is not None:
+            # note that in paddle we require `batch_sampler` to never be None
             res.batch_sampler = dataloader.batch_sampler
             if hasattr(dataloader.batch_sampler, "batch_size"):
                 res.batch_size = getattr(dataloader.batch_sampler, "batch_size")
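`get_dataloader_args` can count on `batch_sampler` being set because `paddle.io.DataLoader` always ends up with one, even when only `batch_size` is passed. A minimal, self-contained demonstration (the dataset is illustrative):

```python
from paddle.io import BatchSampler, DataLoader, Dataset

class ToyDataset(Dataset):
    def __getitem__(self, idx):
        return idx
    def __len__(self):
        return 16

ds = ToyDataset()
loader = DataLoader(ds, batch_sampler=BatchSampler(dataset=ds, batch_size=4, shuffle=True))
assert loader.batch_sampler is not None          # the invariant the comment relies on
assert loader.batch_sampler.batch_size == 4
```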
@@ -11,7 +11,7 @@ from fastNLP.core.utils import (
     get_paddle_device_id,
     paddle_move_data_to_device,
 )
-from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler
 from fastNLP.core.log import logger

 if _NEED_IMPORT_PADDLE:
@@ -137,55 +137,34 @@ class PaddleSingleDriver(PaddleDriver):
         """
         return paddle_move_data_to_device(batch, "gpu:0")

-<<<<<<< HEAD
-    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
-=======
-    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler],
->>>>>>> 388e426d78e8985a2f34dc83dfffe881274239a1
-                                  reproducible: bool = False, sampler_or_batch_sampler=None):
-        # IteratorDataset is not supported for now
+    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None,
+                                  reproducible: bool = False):
+        # IterableDataset is not supported for now
         assert dataloader.dataset_kind != _DatasetKind.ITER, \
-        "FastNLP does not support `IteratorDataset` now."
+            "FastNLP does not support `IteratorDataset` now."
         # if dist is a ReproducibleBatchSampler or ReproducibleIterator, this call comes from driver.load during checkpoint retraining;
         if isinstance(dist, ReproducibleBatchSampler):
-<<<<<<< HEAD
             return replace_batch_sampler(dataloader, dist)
-        elif isinstance(dist, ReproducibleIterator):
-            return replace_sampler(dataloader, dist)
-        if reproducible:
-            args = self.get_dataloader_args(dataloader)
-            if isinstance(args.sampler, ReproducibleIterator):
-                sampler = re_instantiate_sampler(args.sampler)
-                return replace_sampler(dataloader, sampler)
-            elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler):
-                batch_sampler = re_instantiate_sampler(dataloader.batch_sampler)
-                return replace_batch_sampler(dataloader, batch_sampler)
-            else:
-                batch_sampler = ReproducibleBatchSampler(
-                    batch_sampler=args.batch_sampler,
-                    batch_size=args.batch_size,
-                    drop_last=args.drop_last
-=======
-            dataloader.batch_sampler = dist
-            return dataloader
-        if isinstance(dist, ReproducibleSampler):
-            dataloader.batch_sampler.sampler = dist
-            return dataloader
+        elif isinstance(dist, ReproducibleSampler):
+            return replace_sampler(dataloader, dist)
+        # if dist is a str or None, this call comes from trainer initialization;
+        args = self.get_dataloader_args(dataloader)
+        if isinstance(args.batch_sampler, ReproducibleBatchSampler):
+            batch_sampler = re_instantiate_sampler(args.batch_sampler)
+            return replace_batch_sampler(dataloader, batch_sampler)
+        elif isinstance(args.sampler, ReproducibleSampler):
+            sampler = re_instantiate_sampler(args.sampler)
+            return replace_sampler(dataloader, sampler)
         if reproducible:
-            if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler):
-                return dataloader
-            elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler):
-                return dataloader
-            else:
-                # TODO
-                batch_sampler = ReproducibleBatchSampler(
-                    batch_sampler=dataloader.batch_sampler,
-                    batch_size=dataloader.batch_sampler.batch_size,
-                    drop_last=dataloader.drop_last
->>>>>>> 388e426d78e8985a2f34dc83dfffe881274239a1
-                )
-                return replace_batch_sampler(dataloader, batch_sampler)
+            batch_sampler = ReproducibleBatchSampler(
+                batch_sampler=args.batch_sampler,
+                batch_size=args.batch_size,
+                drop_last=args.drop_last
+            )
+            return replace_batch_sampler(dataloader, batch_sampler)
         else:
             return dataloader
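With the accidentally committed conflict markers gone, `PaddleSingleDriver.set_dist_repro_dataloader` implements the same protocol as the fleet version, minus the `set_distributed` bookkeeping. A hedged summary of the contract as a caller sees it (`driver`, `loader`, and `restored_batch_sampler` are illustrative names):

```python
# 1. driver.load passes the restored sampler object: it is swapped in as-is.
loader = driver.set_dist_repro_dataloader(loader, dist=restored_batch_sampler)

# 2. Trainer init, reproducible sampler already on the loader: it is
#    re-instantiated so that consumed state does not leak between runs.
loader = driver.set_dist_repro_dataloader(loader, dist=None, reproducible=False)

# 3. reproducible=True on a plain loader: its batch_sampler is wrapped in
#    ReproducibleBatchSampler, making checkpoint retraining possible later.
loader = driver.set_dist_repro_dataloader(loader, dist=None, reproducible=True)
```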