@@ -184,7 +184,7 @@ class Callback: | |||||
""" | """ | ||||
pass | pass | ||||
def on_before_optimizer_step(self, trainer, optimizers): | |||||
def on_before_optimizers_step(self, trainer, optimizers): | |||||
""" | """ | ||||
在进行 optimizer 优化前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | 在进行 optimizer 优化前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | ||||
@@ -194,6 +194,16 @@ class Callback: | |||||
""" | """ | ||||
pass | pass | ||||
def on_after_optimizers_step(self, trainer, optimizers): | |||||
""" | |||||
在进行 optimizer 优化后调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | |||||
:param trainer: | |||||
:param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 | |||||
:return: | |||||
""" | |||||
pass | |||||
def on_before_zero_grad(self, trainer, optimizers): | def on_before_zero_grad(self, trainer, optimizers): | ||||
""" | """ | ||||
在进行模型梯度置零前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | 在进行模型梯度置零前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | ||||
@@ -204,6 +214,16 @@ class Callback: | |||||
""" | """ | ||||
pass | pass | ||||
def on_after_zero_grad(self, trainer, optimizers): | |||||
""" | |||||
在进行模型梯度置零后调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | |||||
:param trainer: | |||||
:param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 | |||||
:return: | |||||
""" | |||||
pass | |||||
def on_validate_begin(self, trainer): | def on_validate_begin(self, trainer): | ||||
""" | """ | ||||
在将要进行 validate 时调用。如果是设置的以 step 数量 或 自定义地 决定 validate 的频率,该接口是在 on_train_batch_end 之后 | 在将要进行 validate 时调用。如果是设置的以 step 数量 或 自定义地 决定 validate 的频率,该接口是在 on_train_batch_end 之后 | ||||
@@ -92,8 +92,10 @@ class Events(EventEnum): | |||||
ON_LOAD_CHECKPOINT = "on_load_checkpoint" | ON_LOAD_CHECKPOINT = "on_load_checkpoint" | ||||
ON_BEFORE_BACKWARD = "on_before_backward" | ON_BEFORE_BACKWARD = "on_before_backward" | ||||
ON_AFTER_BACKWARD = "on_after_backward" | ON_AFTER_BACKWARD = "on_after_backward" | ||||
ON_BEFORE_OPTIMIZER_STEP = "on_before_optimizer_step" | |||||
ON_BEFORE_OPTIMIZERS_STEP = "on_before_optimizers_step" | |||||
ON_AFTER_OPTIMIZERS_STEP = "on_after_optimizers_step" | |||||
ON_BEFORE_ZERO_GRAD = "on_before_zero_grad" | ON_BEFORE_ZERO_GRAD = "on_before_zero_grad" | ||||
ON_AFTER_ZERO_GRAD = "on_after_zero_grad" | |||||
ON_VALIDATE_BEGIN = "on_validate_begin" | ON_VALIDATE_BEGIN = "on_validate_begin" | ||||
ON_VALIDATE_END = "on_validate_end" | ON_VALIDATE_END = "on_validate_end" | ||||
@@ -278,13 +278,21 @@ class CallbackManager: | |||||
pass | pass | ||||
@_transfer | @_transfer | ||||
def on_before_optimizer_step(self, trainer, optimizers): | |||||
def on_before_optimizers_step(self, trainer, optimizers): | |||||
pass | |||||
@_transfer | |||||
def on_after_optimizers_step(self, trainer, optimizers): | |||||
pass | pass | ||||
@_transfer | @_transfer | ||||
def on_before_zero_grad(self, trainer, optimizers): | def on_before_zero_grad(self, trainer, optimizers): | ||||
pass | pass | ||||
@_transfer | |||||
def on_after_zero_grad(self, trainer, optimizers): | |||||
pass | |||||
@_transfer | @_transfer | ||||
def on_validate_begin(self, trainer): | def on_validate_begin(self, trainer): | ||||
pass | pass | ||||
@@ -137,6 +137,7 @@ class Trainer(TrainerEventTrigger): | |||||
else: | else: | ||||
self.driver_name = driver.__class__.__name__ | self.driver_name = driver.__class__.__name__ | ||||
self.device = device | self.device = device | ||||
self.optimizers = optimizers | |||||
self.fp16 = fp16 | self.fp16 = fp16 | ||||
self.input_mapping = input_mapping | self.input_mapping = input_mapping | ||||
self.output_mapping = output_mapping | self.output_mapping = output_mapping | ||||
@@ -442,9 +443,11 @@ class Trainer(TrainerEventTrigger): | |||||
2. 函数作用 | 2. 函数作用 | ||||
这一函数的作用在于检查用户定制的 batch_step_fn / TrainBatchLoop 是否能够正确地调用 callback 函数,更准确地说,当用户实际 | 这一函数的作用在于检查用户定制的 batch_step_fn / TrainBatchLoop 是否能够正确地调用 callback 函数,更准确地说,当用户实际 | ||||
定制了 ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") / | |||||
定制了 ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad", | |||||
"on_after_zero_grad") / | |||||
("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end", | ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end", | ||||
"on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") | |||||
"on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad", | |||||
"on_after_zero_grad") | |||||
这些 callback_fn 后,如果其同样也定制了 batch_step_fn / TrainBatchLoop,那么其有可能忘记了在自己的 batch_step_fn 中 | 这些 callback_fn 后,如果其同样也定制了 batch_step_fn / TrainBatchLoop,那么其有可能忘记了在自己的 batch_step_fn 中 | ||||
调用上述的这些 callback 函数,而这个函数的作用就在于检测用户是否产生了这一行为; | 调用上述的这些 callback 函数,而这个函数的作用就在于检测用户是否产生了这一行为; | ||||
@@ -454,10 +457,12 @@ class Trainer(TrainerEventTrigger): | |||||
'batch_step_fn',为 False 时表示检测 'TrainBatchLoop'; | 'batch_step_fn',为 False 时表示检测 'TrainBatchLoop'; | ||||
""" | """ | ||||
if check_mode: | if check_mode: | ||||
callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") | |||||
callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", | |||||
"on_before_zero_grad", "on_after_zero_grad") | |||||
else: | else: | ||||
callbacks = ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end", | callbacks = ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end", | ||||
"on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") | |||||
"on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", | |||||
"on_before_zero_grad", "on_after_zero_grad") | |||||
_not_called_callback_fns = [] | _not_called_callback_fns = [] | ||||
for each_callback_fn in callbacks: | for each_callback_fn in callbacks: | ||||
if each_callback_fn in self.callback_manager.callback_fns: | if each_callback_fn in self.callback_manager.callback_fns: | ||||
@@ -707,13 +712,15 @@ class Trainer(TrainerEventTrigger): | |||||
def zero_grad(self): | def zero_grad(self): | ||||
if (self.global_forward_batches + 1) % self.accumulation_steps == 0: | if (self.global_forward_batches + 1) % self.accumulation_steps == 0: | ||||
self.on_before_zero_grad(self.driver.optimizers) | |||||
self.on_before_zero_grad(self.optimizers) | |||||
self.driver.zero_grad(self.set_grad_to_none) | self.driver.zero_grad(self.set_grad_to_none) | ||||
self.on_after_zero_grad(self.optimizers) | |||||
def step(self): | def step(self): | ||||
if (self.global_forward_batches + 1) % self.accumulation_steps == 0: | if (self.global_forward_batches + 1) % self.accumulation_steps == 0: | ||||
self.on_before_optimizer_step(self.driver.optimizers) | |||||
self.on_before_optimizers_step(self.optimizers) | |||||
self.driver.step() | self.driver.step() | ||||
self.on_after_optimizers_step(self.optimizers) | |||||
def move_data_to_device(self, batch): | def move_data_to_device(self, batch): | ||||
return self.driver.move_data_to_device(batch) | return self.driver.move_data_to_device(batch) | ||||
@@ -825,3 +832,5 @@ class Trainer(TrainerEventTrigger): | |||||
@@ -68,12 +68,18 @@ class TrainerEventTrigger: | |||||
def on_after_backward(self): | def on_after_backward(self): | ||||
self.callback_manager.on_after_backward(self) | self.callback_manager.on_after_backward(self) | ||||
def on_before_optimizer_step(self, optimizers): | |||||
self.callback_manager.on_before_optimizer_step(self, optimizers) | |||||
def on_before_optimizers_step(self, optimizers): | |||||
self.callback_manager.on_before_optimizers_step(self, optimizers) | |||||
def on_after_optimizers_step(self, optimizers): | |||||
self.callback_manager.on_after_optimizers_step(self, optimizers) | |||||
def on_before_zero_grad(self, optimizers): | def on_before_zero_grad(self, optimizers): | ||||
self.callback_manager.on_before_zero_grad(self, optimizers) | self.callback_manager.on_before_zero_grad(self, optimizers) | ||||
def on_after_zero_grad(self, optimizers): | |||||
self.callback_manager.on_after_zero_grad(self, optimizers) | |||||
def on_validate_begin(self): | def on_validate_begin(self): | ||||
self.callback_manager.on_validate_begin(self) | self.callback_manager.on_validate_begin(self) | ||||
@@ -10,6 +10,8 @@ from .utils import ( | |||||
_MODE_PARAMETER, | _MODE_PARAMETER, | ||||
get_device_from_visible, | get_device_from_visible, | ||||
reset_seed, | reset_seed, | ||||
replace_sampler, | |||||
replace_batch_sampler, | |||||
) | ) | ||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
@@ -19,8 +21,17 @@ from fastNLP.core.utils import ( | |||||
paddle_move_data_to_device, | paddle_move_data_to_device, | ||||
is_in_paddle_dist, | is_in_paddle_dist, | ||||
) | ) | ||||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedRandomSampler | |||||
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES | |||||
from fastNLP.core.samplers import ( | |||||
RandomBatchSampler, | |||||
ReproducibleSampler, | |||||
ReproducibleBatchSampler, | |||||
RandomSampler, | |||||
UnrepeatedSampler, | |||||
UnrepeatedSequentialSampler, | |||||
re_instantiate_sampler, | |||||
conversion_between_reproducible_and_unrepeated_sampler, | |||||
) | |||||
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
@@ -93,8 +104,8 @@ class PaddleFleetDriver(PaddleDriver): | |||||
# 我们就直接将 model_device 置为 None; | # 我们就直接将 model_device 置为 None; | ||||
self._model_device = None | self._model_device = None | ||||
def _running_fn_(batch, step_fn, signature_fn): | |||||
if isinstance(batch, Dict): | |||||
def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call): | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(step_fn, batch, signature_fn=signature_fn) | return auto_param_call(step_fn, batch, signature_fn=signature_fn) | ||||
else: | else: | ||||
return self._validate_step(batch) | return self._validate_step(batch) | ||||
@@ -105,23 +116,21 @@ class PaddleFleetDriver(PaddleDriver): | |||||
"Notice your model is a `paddle.DataParallel` model. And your " | "Notice your model is a `paddle.DataParallel` model. And your " | ||||
"model also implements the `train_step` method, which we can not call actually, we will" | "model also implements the `train_step` method, which we can not call actually, we will" | ||||
" call `forward` function instead of `train_step` and you should note that.") | " call `forward` function instead of `train_step` and you should note that.") | ||||
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||||
# self._train_signature_fn = model.forward | |||||
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
if hasattr(model, "validate_step"): | if hasattr(model, "validate_step"): | ||||
logger.warning( | logger.warning( | ||||
"Notice your model is a `paddle.DataParallel` model. And your " | "Notice your model is a `paddle.DataParallel` model. And your " | ||||
"model also implements the `validate_step` method, which we can not call actually, " | "model also implements the `validate_step` method, which we can not call actually, " | ||||
"we will call `forward` function instead of `validate_step` and you should note that.") | "we will call `forward` function instead of `validate_step` and you should note that.") | ||||
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||||
# self._validate_signature_fn = model.forward | |||||
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
if hasattr(model, "test_step"): | if hasattr(model, "test_step"): | ||||
logger.warning( | logger.warning( | ||||
"Notice your model is a `paddle.DataParallel` model. And your " | "Notice your model is a `paddle.DataParallel` model. And your " | ||||
"model also implements the `test_step` method, which we can not call actually, we will" | "model also implements the `test_step` method, which we can not call actually, we will" | ||||
" call `forward` function instead of `test_step` and you should note that.") | " call `forward` function instead of `test_step` and you should note that.") | ||||
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||||
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
# 当参数 `device` 为 None 时并且该参数不为 None,表示将对应的数据移到指定的机器上; | # 当参数 `device` 为 None 时并且该参数不为 None,表示将对应的数据移到指定的机器上; | ||||
self._data_device = kwargs.get("data_device", None) | self._data_device = kwargs.get("data_device", None) | ||||
@@ -253,7 +262,6 @@ class PaddleFleetDriver(PaddleDriver): | |||||
当用户使用了 `python -m paddle.distributed.launch xxx.py` 启动时,我们需要 | 当用户使用了 `python -m paddle.distributed.launch xxx.py` 启动时,我们需要 | ||||
根据 paddle 设置的环境变量来获得各种属性 | 根据 paddle 设置的环境变量来获得各种属性 | ||||
""" | """ | ||||
print("set_from_env") | |||||
self.world_size = dist.get_world_size() | self.world_size = dist.get_world_size() | ||||
self.global_rank = dist.get_rank() | self.global_rank = dist.get_rank() | ||||
@@ -267,9 +275,9 @@ class PaddleFleetDriver(PaddleDriver): | |||||
**self._fleet_kwargs | **self._fleet_kwargs | ||||
) | ) | ||||
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}) | |||||
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}) | |||||
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}) | |||||
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call) | |||||
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call) | |||||
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call) | |||||
self._configured = True | self._configured = True | ||||
@@ -312,67 +320,90 @@ class PaddleFleetDriver(PaddleDriver): | |||||
def test_step(self, batch): | def test_step(self, batch): | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]], | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]], | |||||
reproducible: bool = False, sampler_or_batch_sampler=None): | reproducible: bool = False, sampler_or_batch_sampler=None): | ||||
# 暂时不支持iterableDataset | # 暂时不支持iterableDataset | ||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | assert dataloader.dataset_kind != _DatasetKind.ITER, \ | ||||
"FastNLP does not support `IteratorDataset` now." | "FastNLP does not support `IteratorDataset` now." | ||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用; | |||||
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; | |||||
if isinstance(dist, ReproducibleBatchSampler): | |||||
dist.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_batch_sampler(dataloader, dist) | |||||
if isinstance(dist, ReproducibleSampler): | if isinstance(dist, ReproducibleSampler): | ||||
dataloader.batch_sampler.sampler = dist | |||||
return dataloader | |||||
# paddle 的 BatchSampler 和 DataLoader 没有 shuffle 成员,只能根据 sampler 判断 | |||||
# 但是其子类 DistributedBatchSampler 却有 shuffle 成员 | |||||
# 因此用 type() 进行严格的判断 | |||||
if type(dataloader.batch_sampler) == BatchSampler: | |||||
shuffle = isinstance(dataloader.batch_sampler.sampler, RandomSampler) | |||||
else: | |||||
shuffle = dataloader.batch_sampler.shuffle | |||||
dist.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_sampler(dataloader, dist) | |||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初始化时调用; | |||||
# trainer, evaluator | # trainer, evaluator | ||||
if dist is None: | if dist is None: | ||||
if reproducible: | if reproducible: | ||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our " | |||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our " | |||||
"control.") | "control.") | ||||
else: | else: | ||||
if isinstance(dist, ReproducibleBatchSampler): | |||||
dist = re_instantiate_sampler(dist) | |||||
return replace_batch_sampler(dataloader, dist) | |||||
if isinstance(dist, ReproducibleSampler): | |||||
dist = re_instantiate_sampler(dist) | |||||
return replace_sampler(dataloader, dist) | |||||
return dataloader | return dataloader | ||||
# trainer | # trainer | ||||
elif dist == "dist": | elif dist == "dist": | ||||
args = self.get_dataloader_args(dataloader) | |||||
# 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | ||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||||
dataloader.batch_sampler.sampler.set_distributed( | |||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||||
batch_sampler.set_distributed( | |||||
num_replicas=self.world_size, | num_replicas=self.world_size, | ||||
rank=self.global_rank, | rank=self.global_rank, | ||||
pad=True | pad=True | ||||
) | ) | ||||
return dataloader | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
elif isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
sampler.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_sampler(dataloader, sampler) | |||||
else: | else: | ||||
sampler = RandomSampler( | sampler = RandomSampler( | ||||
dataset=dataloader.dataset, | |||||
shuffle=shuffle, | |||||
seed=int(os.environ.get("FASTNLP_SEED", 0)) | |||||
dataset=args.dataset, | |||||
shuffle=args.shuffle, | |||||
seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0)) | |||||
) | ) | ||||
sampler.set_distributed( | sampler.set_distributed( | ||||
num_replicas=self.world_size, | num_replicas=self.world_size, | ||||
rank=self.global_rank, | rank=self.global_rank, | ||||
pad=True | pad=True | ||||
) | ) | ||||
dataloader.batch_sampler.sampler = sampler | |||||
return dataloader | |||||
return replace_sampler(dataloader, sampler) | |||||
# evaluator | # evaluator | ||||
elif dist == "unrepeatdist": | elif dist == "unrepeatdist": | ||||
sampler = UnrepeatedRandomSampler( | |||||
dataset=dataloader.dataset, | |||||
shuffle=shuffle, | |||||
seed=int(os.environ.get("FASTNLP_SEED", 0)) | |||||
) | |||||
args = self.get_dataloader_args(dataloader) | |||||
if isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) | |||||
elif not isinstance(args.sampler, UnrepeatedSampler): | |||||
sampler = UnrepeatedSequentialSampler( | |||||
dataset=args.dataset | |||||
) | |||||
else: | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
sampler.set_distributed( | sampler.set_distributed( | ||||
num_replicas=self.world_size, | num_replicas=self.world_size, | ||||
rank=self.global_rank | rank=self.global_rank | ||||
) | ) | ||||
dataloader.batch_sampler.sampler = sampler | |||||
return dataloader | |||||
return replace_sampler(dataloader, sampler) | |||||
else: | else: | ||||
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | ||||
@@ -38,23 +38,19 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
if driver not in {"paddle", "fleet"}: | if driver not in {"paddle", "fleet"}: | ||||
raise ValueError("Parameter `driver` can only be one of these values: ['paddle', 'fleet'].") | raise ValueError("Parameter `driver` can only be one of these values: ['paddle', 'fleet'].") | ||||
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") | |||||
user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES") | user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES") | ||||
# 优先级 user > cuda | |||||
# 判断单机情况 device 的合法性 | |||||
# 分布式情况下通过 world_device 判断 | |||||
if user_visible_devices != "": | |||||
_could_use_device_num = len(user_visible_devices.split(",")) | |||||
elif cuda_visible_devices is not None: | |||||
_could_use_device_num = len(cuda_visible_devices.split(",")) | |||||
else: | |||||
_could_use_device_num = paddle.device.cuda.device_count() | |||||
if user_visible_devices is None: | |||||
raise RuntimeError("This situation cannot happen, please report a bug to us.") | |||||
_could_use_device_num = len(user_visible_devices.split(",")) | |||||
if isinstance(device, int): | if isinstance(device, int): | ||||
if device < 0 and device != -1: | if device < 0 and device != -1: | ||||
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | ||||
# if device >= _could_use_device_num: | |||||
# raise ValueError("The gpu device that parameter `device` specifies is not existed.") | |||||
device = f"gpu:{device}" | |||||
if device >= _could_use_device_num: | |||||
raise ValueError("The gpu device that parameter `device` specifies is not existed.") | |||||
if device != -1: | |||||
device = f"gpu:{device}" | |||||
else: | |||||
device = list(range(_could_use_device_num)) | |||||
elif isinstance(device, Sequence) and not isinstance(device, str): | elif isinstance(device, Sequence) and not isinstance(device, str): | ||||
device = list(set(device)) | device = list(set(device)) | ||||
for each in device: | for each in device: | ||||
@@ -62,6 +58,9 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be 'int' type.") | raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be 'int' type.") | ||||
elif each < 0: | elif each < 0: | ||||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be bigger than 0.") | raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be bigger than 0.") | ||||
elif each >= _could_use_device_num: | |||||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should not be bigger than" | |||||
" the available gpu number.") | |||||
if len(device) == 1: | if len(device) == 1: | ||||
# 传入了 [1] 这样的,视为单卡。 | # 传入了 [1] 这样的,视为单卡。 | ||||
device = device[0] | device = device[0] | ||||
@@ -1,21 +1,36 @@ | |||||
import os | import os | ||||
import random | import random | ||||
from typing import Union, Optional, Callable, Dict | |||||
from typing import Union, Optional, Dict | |||||
from pathlib import Path | |||||
from functools import partial | from functools import partial | ||||
from dataclasses import dataclass | |||||
import numpy as np | import numpy as np | ||||
from .utils import _build_fp16_env | |||||
from .utils import _build_fp16_env, optimizer_state_to_device | |||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
from fastNLP.core.drivers.driver import Driver | from fastNLP.core.drivers.driver import Driver | ||||
from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device | from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device | ||||
from fastNLP.envs import rank_zero_call | |||||
from fastNLP.envs import FASTNLP_SEED_WORKERS | |||||
from fastNLP.envs import ( | |||||
FASTNLP_SEED_WORKERS, | |||||
FASTNLP_MODEL_FILENAME, | |||||
FASTNLP_CHECKPOINT_FILENAME, | |||||
FASTNLP_GLOBAL_RANK, | |||||
rank_zero_call, | |||||
) | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler | |||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
import paddle | import paddle | ||||
from paddle.io import DataLoader, IterableDataset | |||||
from paddle.io import ( | |||||
DataLoader, | |||||
IterableDataset, | |||||
Dataset, | |||||
Sampler, | |||||
BatchSampler, | |||||
RandomSampler, | |||||
) | |||||
from paddle.optimizer import Optimizer | from paddle.optimizer import Optimizer | ||||
_reduces = { | _reduces = { | ||||
@@ -41,6 +56,9 @@ class PaddleDriver(Driver): | |||||
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | ||||
self.grad_scaler = _grad_scaler() | self.grad_scaler = _grad_scaler() | ||||
# 用来设置是否关闭 auto_param_call 中的参数匹配问题; | |||||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | |||||
def zero_grad(self, set_to_none: bool = False): | def zero_grad(self, set_to_none: bool = False): | ||||
r""" | r""" | ||||
实现深度学习中的梯度的置零操作,应当直接通过优化器 optimizers 来将梯度置零; | 实现深度学习中的梯度的置零操作,应当直接通过优化器 optimizers 来将梯度置零; | ||||
@@ -69,6 +87,8 @@ class PaddleDriver(Driver): | |||||
# TODO 我们先禁止 dataloader 的 dataset 是 IterableDataset 种类; | # TODO 我们先禁止 dataloader 的 dataset 是 IterableDataset 种类; | ||||
if isinstance(dataloader.dataset, IterableDataset): | if isinstance(dataloader.dataset, IterableDataset): | ||||
raise TypeError("`IterableDataset` is not allowed.") | raise TypeError("`IterableDataset` is not allowed.") | ||||
if dataloader.batch_sampler is None and dataloader.batch_size is None: | |||||
raise ValueError(f"At least one of `{dataloader_name}`'s `batch_sampler` and `batch_size` should be set.") | |||||
else: | else: | ||||
if not isinstance(dataloader, Dict): | if not isinstance(dataloader, Dict): | ||||
raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.") | raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.") | ||||
@@ -79,6 +99,9 @@ class PaddleDriver(Driver): | |||||
f"type, not {type(each_dataloader)}.") | f"type, not {type(each_dataloader)}.") | ||||
if isinstance(each_dataloader.dataset, IterableDataset): | if isinstance(each_dataloader.dataset, IterableDataset): | ||||
raise TypeError("`IterableDataset` is not allowed.") | raise TypeError("`IterableDataset` is not allowed.") | ||||
if each_dataloader.batch_sampler is None and each_dataloader.batch_size is None: | |||||
raise ValueError(f"For each dataloader of parameter `{dataloader_name}`, at least one of " | |||||
f"`batch_sampler` and `batch_size` should be set.") | |||||
@staticmethod | @staticmethod | ||||
def _check_optimizer_legality(optimizers): | def _check_optimizer_legality(optimizers): | ||||
@@ -153,45 +176,55 @@ class PaddleDriver(Driver): | |||||
getattr(self.model, mode)() | getattr(self.model, mode)() | ||||
@rank_zero_call | @rank_zero_call | ||||
def save_model(self, filepath: str, only_state_dict: bool = True, model_save_fn: Optional[Callable]=None, **kwargs): | |||||
def save_model(self, filepath: str, only_state_dict: bool = True, **kwargs): | |||||
r""" | r""" | ||||
保存模型的函数;注意函数 `save` 是用来进行断点重训的函数; | 保存模型的函数;注意函数 `save` 是用来进行断点重训的函数; | ||||
如果 `model_save_fn` 是一个可调用的函数,那么我们会直接运行该函数; | |||||
:param filepath: 保存文件的文件位置(需要包括文件名); | :param filepath: 保存文件的文件位置(需要包括文件名); | ||||
:param only_state_dict: 是否只保存模型的 `state_dict`;注意该参数仅当 `model_save_fn` 为 None 时有效; | |||||
:param model_save_fn: 用户传入的用来代替该函数本身保存逻辑的函数;如果该参数不为 None,那么我们会调用 model_save_fn(path); | |||||
:param only_state_dict: 是否只保存模型的 `state_dict`;如果为 False,则会调用 `paddle.jit.save` 函数 | |||||
保存整个模型的参数,此时需要传入 `input_spec` 参数,否则在 load 时会报错。 | |||||
:param kwargs: | |||||
input_spec: 描述存储模型 forward 方法的输入,当 `only_state_dict` 为 False时必须传入,否则加载时会报错。 | |||||
可以通过 InputSpec 或者示例 Tensor 进行描述。详细的可以参考 paddle 关于`paddle.jit.save` | |||||
的文档: | |||||
https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/jit/save_cn.html#save | |||||
:return: | |||||
""" | """ | ||||
if model_save_fn is not None: | |||||
model_save_fn(filepath) | |||||
model = self.unwrap_model() | |||||
if isinstance(filepath, Path): | |||||
filepath = str(filepath) | |||||
if only_state_dict: | |||||
states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} | |||||
paddle.save(states, filepath) | |||||
else: | else: | ||||
model = self.unwrap_model() | |||||
if only_state_dict: | |||||
paddle.save(model.state_dict(), filepath) | |||||
else: | |||||
input_spec = kwargs.get("input_spec", None) | |||||
if input_spec is None: | |||||
raise Exception("To save the whole Paddle Layer, parameter 'input_spec' is needed.") | |||||
paddle.jit.save(model, filepath, input_spec) | |||||
# paddle 在保存整个模型时需要传入额外参数 | |||||
input_spec = kwargs.get("input_spec", None) | |||||
if input_spec is None: | |||||
raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.") | |||||
paddle.jit.save(model, filepath, input_spec) | |||||
@staticmethod | |||||
@rank_zero_call | |||||
def load_model(filepath: str, load_dict: bool = True): | |||||
def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs): | |||||
r""" | r""" | ||||
加载模型的函数;注意函数 `load` 是用来进行断点重训的函数; | 加载模型的函数;注意函数 `load` 是用来进行断点重训的函数; | ||||
:param filepath: 需要被加载的对象的文件位置(需要包括文件名); | :param filepath: 需要被加载的对象的文件位置(需要包括文件名); | ||||
:param load_dict: 是否加载state_dict,默认为True。当用户在save_model时将only_state_dict设置为False时, | |||||
即保存了整个模型时,这个参数必须也为False | |||||
:return: 返回加载指定文件后的结果; | |||||
:param only_state_dict: 是否加载state_dict,默认为True。 | |||||
:param kwargs: | |||||
:return: | |||||
""" | """ | ||||
if load_dict: | |||||
return paddle.load(filepath) | |||||
else: | |||||
return paddle.jit.load(filepath) | |||||
model = self.unwrap_model() | |||||
if isinstance(filepath, Path): | |||||
filepath = str(filepath) | |||||
# paddle 中,通过 paddle.jit.save 函数保存的模型也可以通过 paddle.load 加载为相应的 state dict | |||||
# 但是此时对输入的 path 有要求,必须是 dir/filename 的形式,否则会报错。 | |||||
dirname, filename = os.path.split(filepath) | |||||
if not only_state_dict and dirname == "": | |||||
# 如果传入的是单个文件,则加上相对路径 | |||||
filepath = os.path.join(".", filepath) | |||||
model.load_dict(paddle.load(filepath)) | |||||
@rank_zero_call | @rank_zero_call | ||||
def save(self, folder, states: Dict): | |||||
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||||
r""" | r""" | ||||
断点重训的保存函数,该函数会负责保存模型和 optimizers 的 state_dict; | 断点重训的保存函数,该函数会负责保存模型和 optimizers 的 state_dict; | ||||
需要注意 driver 应当是无状态的,即不管什么时候调用 driver 的接口函数,其返回的结果应该都是一样的;因此,断点重训不需要保存 driver | 需要注意 driver 应当是无状态的,即不管什么时候调用 driver 的接口函数,其返回的结果应该都是一样的;因此,断点重训不需要保存 driver | ||||
@@ -203,48 +236,101 @@ class PaddleDriver(Driver): | |||||
:param states: 由 trainer 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态,Driver 应该只需要保存 | :param states: 由 trainer 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态,Driver 应该只需要保存 | ||||
该对象即可, Driver 应该不需要理解该对象,同时在 driver.load() 的时候,需要将 states 返回回去,load()返回的值与这里的 | 该对象即可, Driver 应该不需要理解该对象,同时在 driver.load() 的时候,需要将 states 返回回去,load()返回的值与这里的 | ||||
传入的值保持一致。 | 传入的值保持一致。 | ||||
:param dataloader: 正在使用的 dataloader,需要保存里面的状态使得之后可以从当前迭代的位置恢复。 | |||||
:param only_state_dict: 是否只保存模型的参数,当 should_save_model 为 False ,该参数无效。 | |||||
:param should_save_model: 是否应该保存模型,如果为False,Driver 将不负责 model 的保存。 | |||||
:return: | |||||
""" | """ | ||||
# 1. 保存模型的状态; | |||||
model = self.unwrap_model() | |||||
model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} | |||||
# 对于单卡的 driver 来讲,我们实际上(现在)不应该考虑用户在DDP环境下使用单卡模式,从而造成效率损失; | |||||
states["model_state_dict"] = model_state_dict | |||||
# 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变 | |||||
# trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; | |||||
# 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; | |||||
# paddle 的 DataLoader 在初始化之后 batch_sampler 可能为 None,也可能为用户设置的 batch_sampler | |||||
dataloader_args = self.get_dataloader_args(dataloader) | |||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||||
sampler = dataloader_args.batch_sampler | |||||
elif dataloader_args.sampler: | |||||
sampler = dataloader_args.sampler | |||||
else: | |||||
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | |||||
# 2. 保存 optimizers 的状态; | |||||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | |||||
states['sampler_states'] = sampler.state_dict() | |||||
else: | |||||
raise RuntimeError( | |||||
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | |||||
# 2. 保存模型的状态; | |||||
if should_save_model: | |||||
self.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs) | |||||
if only_state_dict: | |||||
logger.debug("Save model state dict.") | |||||
else: | |||||
logger.debug("Save model.") | |||||
# 3. 保存 optimizers 的状态; | |||||
optimizers_state_dict = {} | optimizers_state_dict = {} | ||||
for i in range(len(self.optimizers)): | for i in range(len(self.optimizers)): | ||||
optimizer: Optimizer = self.optimizers[i] | optimizer: Optimizer = self.optimizers[i] | ||||
optimizer_state = optimizer.state_dict() | optimizer_state = optimizer.state_dict() | ||||
optimizer_state = {name: param.cpu().detach().clone() for name, param in optimizer_state.items()} | |||||
optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu") | |||||
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | ||||
states["optimizers_state_dict"] = optimizers_state_dict | |||||
paddle.save(states, folder) | |||||
def load(self, filepath) -> Dict: | |||||
r""" | |||||
断点重训的加载函数,注意该函数会负责读取数据,并且恢复模型和 optimizers 的 state_dict 等; | |||||
driver 实例需要在该函数中先加载模型和 optimizers 的 state_dict,然后将一个 state 字典返回给 trainer 。 | |||||
因此 save 函数和 load 函数的接受和返回值应该是对应的; | |||||
logger.debug("Save optimizer state dict.") | |||||
states["optimizers_state_dict"] = optimizers_state_dict | |||||
paddle.save(states, str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) | |||||
该函数需要在所有 rank 上执行。 | |||||
:param filepath: 保存断点重训的状态的文件名; | |||||
:return: 需要返回 save 函数输入的 states 内容; | |||||
""" | |||||
states = paddle.load(filepath) | |||||
def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||||
states = paddle.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) | |||||
# 1. 加载 optimizers 的状态; | # 1. 加载 optimizers 的状态; | ||||
optimizers_state_dict = states["optimizers_state_dict"] | optimizers_state_dict = states["optimizers_state_dict"] | ||||
for i in range(len(self.optimizers)): | for i in range(len(self.optimizers)): | ||||
optimizer: paddle.optimizer.Optimizer = self.optimizers[i] | |||||
optimizer: Optimizer = self.optimizers[i] | |||||
optimizer.set_state_dict(optimizers_state_dict[f"optimizer{i}"]) | optimizer.set_state_dict(optimizers_state_dict[f"optimizer{i}"]) | ||||
logger.debug("Load optimizer state dict.") | |||||
# 2. 加载模型状态; | # 2. 加载模型状态; | ||||
model = self.unwrap_model() | |||||
model.load_dict(states["model_state_dict"]) | |||||
if should_load_model: | |||||
self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict) | |||||
if only_state_dict: | |||||
logger.debug("Load model state dict.") | |||||
else: | |||||
logger.debug("Load model.") | |||||
# 3. 恢复 sampler 的状态; | |||||
dataloader_args = self.get_dataloader_args(dataloader) | |||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||||
sampler = dataloader_args.batch_sampler | |||||
elif isinstance(dataloader_args.sampler, ReproducibleSampler): | |||||
sampler = dataloader_args.sampler | |||||
elif self.is_distributed(): | |||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") | |||||
else: | |||||
sampler = RandomBatchSampler( | |||||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | |||||
batch_size=dataloader_args.batch_size, | |||||
drop_last=dataloader_args.drop_last | |||||
) | |||||
sampler.load_state_dict(states['sampler_states']) | |||||
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | |||||
# 4. 修改 trainer_state.batch_idx_in_epoch | |||||
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | |||||
if not isinstance(sampler, ReproducibleBatchSampler): | |||||
if dataloader_args.drop_last: | |||||
batch_idx_in_epoch = len( | |||||
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size | |||||
else: | |||||
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ | |||||
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size | |||||
# sampler 是 batch_sampler; | |||||
else: | |||||
batch_idx_in_epoch = sampler.batch_idx_in_epoch | |||||
states["batch_idx_in_epoch"] = batch_idx_in_epoch | |||||
self.barrier() | |||||
return states | return states | ||||
def get_evaluate_context(self): | def get_evaluate_context(self): | ||||
@@ -282,7 +368,7 @@ class PaddleDriver(Driver): | |||||
`randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_. | `randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_. | ||||
""" | """ | ||||
# implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 | # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 | ||||
global_rank = rank if rank is not None else rank_zero_call.rank | |||||
global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) | |||||
# TODO gpu | # TODO gpu | ||||
process_seed = paddle.fluid.core.default_cpu_generator().initial_seed() | process_seed = paddle.fluid.core.default_cpu_generator().initial_seed() | ||||
# back out the base seed so we can use all the bits | # back out the base seed so we can use all the bits | ||||
@@ -313,3 +399,64 @@ class PaddleDriver(Driver): | |||||
""" | """ | ||||
if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): | if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): | ||||
dataloader.batch_sampler.set_epoch(cur_epoch_idx) | dataloader.batch_sampler.set_epoch(cur_epoch_idx) | ||||
@staticmethod | |||||
def get_dataloader_args(dataloader: "DataLoader"): | |||||
""" | |||||
获取 dataloader 的 shuffle 和 drop_last 属性; | |||||
""" | |||||
@dataclass | |||||
class Res: | |||||
dataset: Optional[Dataset] = None | |||||
batch_sampler: Optional[BatchSampler] = None | |||||
sampler: Optional[Sampler] = None | |||||
batch_size: Optional[int] = None | |||||
shuffle: Optional[bool] = None | |||||
drop_last: Optional[bool] = None | |||||
res = Res() | |||||
# paddle 的 DataLoader 一定会有 dataset 属性; | |||||
res.dataset = dataloader.dataset | |||||
if dataloader.batch_sampler is not None: | |||||
# 不过在 paddle 中,我们限定了 batch_sampler 不能为 None | |||||
res.batch_sampler = dataloader.batch_sampler | |||||
if hasattr(dataloader.batch_sampler, "batch_size"): | |||||
res.batch_size = getattr(dataloader.batch_sampler, "batch_size") | |||||
# 用户使用的是自己的 batch_sampler 并且其没有 "batch_size" 属性; | |||||
else: | |||||
dataloader_iter = iter(dataloader) | |||||
pre_sample = next(dataloader_iter) | |||||
res.batch_size = pre_sample.shape[0] | |||||
if hasattr(dataloader.batch_sampler, "sampler"): | |||||
res.sampler = dataloader.batch_sampler.sampler | |||||
if hasattr(dataloader.batch_sampler.sampler, "shuffle"): | |||||
res.shuffle = dataloader.batch_sampler.sampler.shuffle | |||||
elif isinstance(dataloader.batch_sampler.sampler, RandomSampler): | |||||
res.shuffle = True | |||||
else: | |||||
res.shuffle = False | |||||
# RandomBatchSampler 的情况 | |||||
elif hasattr(dataloader.batch_sampler, "batch_sampler"): | |||||
batch_sampler = dataloader.batch_sampler.batch_sampler | |||||
res.sampler = batch_sampler.sampler | |||||
if hasattr(batch_sampler.sampler, "shuffle"): | |||||
res.shuffle = dataloader.batch_sampler.sampler.shuffle | |||||
elif isinstance(batch_sampler.sampler, RandomSampler): | |||||
res.shuffle = True | |||||
else: | |||||
res.shuffle = False | |||||
else: | |||||
res.sampler = None | |||||
res.shuffle = False | |||||
if hasattr(dataloader.batch_sampler, "drop_last"): | |||||
res.drop_last = getattr(dataloader.batch_sampler, "drop_last") | |||||
# 用户使用的是自己的 batch_sampler 并且其没有 "drop_last" 属性; | |||||
else: | |||||
res.drop_last = False | |||||
return res |
@@ -2,6 +2,7 @@ import os | |||||
from typing import Optional, Dict, Union | from typing import Optional, Dict, Union | ||||
from .paddle_driver import PaddleDriver | from .paddle_driver import PaddleDriver | ||||
from .utils import replace_batch_sampler, replace_sampler, get_device_from_visible | |||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES | from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES | ||||
from fastNLP.core.utils import ( | from fastNLP.core.utils import ( | ||||
@@ -10,7 +11,12 @@ from fastNLP.core.utils import ( | |||||
get_paddle_device_id, | get_paddle_device_id, | ||||
paddle_move_data_to_device, | paddle_move_data_to_device, | ||||
) | ) | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||||
from fastNLP.core.samplers import ( | |||||
ReproducibleBatchSampler, | |||||
RandomBatchSampler, | |||||
ReproducibleSampler, | |||||
re_instantiate_sampler, | |||||
) | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
@@ -22,16 +28,13 @@ __all__ = [ | |||||
] | ] | ||||
class PaddleSingleDriver(PaddleDriver): | class PaddleSingleDriver(PaddleDriver): | ||||
def __init__(self, model, device: Optional[str], fp16: Optional[bool] = False, **kwargs): | |||||
def __init__(self, model, device: str, fp16: Optional[bool] = False, **kwargs): | |||||
super(PaddleSingleDriver, self).__init__(model, fp16=fp16, **kwargs) | super(PaddleSingleDriver, self).__init__(model, fp16=fp16, **kwargs) | ||||
if device is None: | if device is None: | ||||
raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.") | raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.") | ||||
if isinstance(device, int): | |||||
self.model_device = get_paddle_gpu_str(device) | |||||
else: | |||||
self.model_device = device | |||||
self.model_device = get_paddle_gpu_str(device) | |||||
self.local_rank = 0 | self.local_rank = 0 | ||||
self.global_rank = 0 | self.global_rank = 0 | ||||
@@ -93,18 +96,18 @@ class PaddleSingleDriver(PaddleDriver): | |||||
self._test_signature_fn = model.forward | self._test_signature_fn = model.forward | ||||
def setup(self): | def setup(self): | ||||
user_visible_devices = os.environ[USER_CUDA_VISIBLE_DEVICES] | |||||
device_id = get_paddle_device_id(self.model_device) | |||||
if user_visible_devices is not None and user_visible_devices != "": | |||||
# 不为空,说明用户设置了 CUDA_VISIBLDE_DEVICES | |||||
device_id = user_visible_devices.split(",")[device_id] | |||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) | |||||
paddle.device.set_device("gpu:0") | |||||
self.model.to("gpu:0") | |||||
device = self.model_device | |||||
if device != "cpu": | |||||
device_id = get_paddle_device_id(device) | |||||
device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id] | |||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) | |||||
device = get_device_from_visible(device, output_type=str) | |||||
paddle.device.set_device(device) | |||||
self.model.to(device) | |||||
def train_step(self, batch) -> Dict: | def train_step(self, batch) -> Dict: | ||||
# 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; | # 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; | ||||
if isinstance(batch, Dict): | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | ||||
else: | else: | ||||
return self._train_step(batch) | return self._train_step(batch) | ||||
@@ -118,13 +121,13 @@ class PaddleSingleDriver(PaddleDriver): | |||||
self.grad_scaler.update() | self.grad_scaler.update() | ||||
def validate_step(self, batch) -> Dict: | def validate_step(self, batch) -> Dict: | ||||
if isinstance(batch, Dict): | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | ||||
else: | else: | ||||
return self._validate_step(batch) | return self._validate_step(batch) | ||||
def test_step(self, batch) -> Dict: | def test_step(self, batch) -> Dict: | ||||
if isinstance(batch, Dict): | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | ||||
else: | else: | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
@@ -133,38 +136,40 @@ class PaddleSingleDriver(PaddleDriver): | |||||
r""" | r""" | ||||
将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。 | 将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。 | ||||
在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 | 在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 | ||||
在单卡时,由于 CUDA_VISIBLE_DEVICES 始终被限制在一个设备上,因此实际上只会迁移到 `gpu:0` | |||||
:return: 将移动到指定机器上的 batch 对象返回; | :return: 将移动到指定机器上的 batch 对象返回; | ||||
""" | """ | ||||
return paddle_move_data_to_device(batch, "gpu:0") | |||||
device = get_device_from_visible(self.data_device) | |||||
return paddle_move_data_to_device(batch, device) | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, | |||||
reproducible: bool = False): | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], | |||||
reproducible: bool = False, sampler_or_batch_sampler=None): | |||||
# 暂时不支持IteratorDataset | |||||
# 暂时不支持iterableDataset | |||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | assert dataloader.dataset_kind != _DatasetKind.ITER, \ | ||||
"FastNLP does not support `IteratorDataset` now." | |||||
"FastNLP does not support `IteratorDataset` now." | |||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||||
if isinstance(dist, ReproducibleBatchSampler): | if isinstance(dist, ReproducibleBatchSampler): | ||||
dataloader.batch_sampler = dist | |||||
return dataloader | |||||
if isinstance(dist, ReproducibleSampler): | |||||
dataloader.batch_sampler.sampler = dist | |||||
return dataloader | |||||
return replace_batch_sampler(dataloader, dist) | |||||
elif isinstance(dist, ReproducibleSampler): | |||||
return replace_sampler(dataloader, dist) | |||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | |||||
args = self.get_dataloader_args(dataloader) | |||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
elif isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
return replace_sampler(dataloader, sampler) | |||||
if reproducible: | if reproducible: | ||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||||
return dataloader | |||||
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): | |||||
return dataloader | |||||
else: | |||||
# TODO | |||||
batch_sampler = ReproducibleBatchSampler( | |||||
batch_sampler=dataloader.batch_sampler, | |||||
batch_size=dataloader.batch_sampler.batch_size, | |||||
drop_last=dataloader.drop_last | |||||
) | |||||
dataloader.batch_sampler = batch_sampler | |||||
return dataloader | |||||
batch_sampler = RandomBatchSampler( | |||||
batch_sampler=args.batch_sampler, | |||||
batch_size=args.batch_size, | |||||
drop_last=args.drop_last | |||||
) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
else: | else: | ||||
return dataloader | return dataloader | ||||
@@ -4,12 +4,14 @@ import struct | |||||
import random | import random | ||||
import inspect | import inspect | ||||
import numpy as np | import numpy as np | ||||
from copy import deepcopy | |||||
from contextlib import ExitStack, closing | from contextlib import ExitStack, closing | ||||
from enum import IntEnum | from enum import IntEnum | ||||
from typing import Dict, Optional, Union | from typing import Dict, Optional, Union | ||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
from fastNLP.core.utils import get_paddle_device_id, auto_param_call | |||||
from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to | |||||
from fastNLP.core.samplers import RandomSampler | |||||
from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES | from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -18,7 +20,7 @@ if _NEED_IMPORT_PADDLE: | |||||
import paddle | import paddle | ||||
from paddle import nn | from paddle import nn | ||||
from paddle.nn import Layer | from paddle.nn import Layer | ||||
from paddle.io import DataLoader, BatchSampler | |||||
from paddle.io import DataLoader, BatchSampler, Dataset | |||||
from paddle.amp import auto_cast, GradScaler | from paddle.amp import auto_cast, GradScaler | ||||
else: | else: | ||||
from fastNLP.core.utils.dummy_class import DummyClass as Layer | from fastNLP.core.utils.dummy_class import DummyClass as Layer | ||||
@@ -85,7 +87,7 @@ class ForwardState(IntEnum): | |||||
TEST = 2 | TEST = 2 | ||||
PREDICT = 3 | PREDICT = 3 | ||||
_MODE_PARAMETER = "_forward_state" | |||||
_MODE_PARAMETER = "forward_state" | |||||
class _FleetWrappingModel(Layer): | class _FleetWrappingModel(Layer): | ||||
""" | """ | ||||
@@ -151,24 +153,25 @@ class _FleetWrappingModel(Layer): | |||||
def forward(self, batch, **kwargs) -> Dict: | def forward(self, batch, **kwargs) -> Dict: | ||||
_forward_state = kwargs.pop(_MODE_PARAMETER) | |||||
forward_state = kwargs.pop(_MODE_PARAMETER) | |||||
wo_auto_param_call = kwargs.pop("wo_auto_param_call") | |||||
if _forward_state == ForwardState.TRAIN: | |||||
if isinstance(batch, Dict): | |||||
if forward_state == ForwardState.TRAIN: | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | ||||
else: | else: | ||||
return self._train_step(batch) | return self._train_step(batch) | ||||
elif _forward_state == ForwardState.VALIDATE: | |||||
if isinstance(batch, Dict): | |||||
elif forward_state == ForwardState.VALIDATE: | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | ||||
else: | else: | ||||
return self._validate_step(batch) | return self._validate_step(batch) | ||||
elif _forward_state == ForwardState.TEST: | |||||
if isinstance(batch, Dict): | |||||
elif forward_state == ForwardState.TEST: | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | ||||
else: | else: | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
elif _forward_state == ForwardState.PREDICT: | |||||
elif forward_state == ForwardState.PREDICT: | |||||
raise NotImplementedError("'PREDICT' mode has not been implemented.") | raise NotImplementedError("'PREDICT' mode has not been implemented.") | ||||
else: | else: | ||||
raise NotImplementedError("You should direct a concrete mode.") | raise NotImplementedError("You should direct a concrete mode.") | ||||
@@ -205,7 +208,6 @@ class DummyGradScaler: | |||||
def state_dict(self): | def state_dict(self): | ||||
return {} | return {} | ||||
def _build_fp16_env(dummy=False): | def _build_fp16_env(dummy=False): | ||||
if dummy: | if dummy: | ||||
auto_cast = ExitStack | auto_cast = ExitStack | ||||
@@ -255,61 +257,77 @@ def get_host_name_ip(): | |||||
except: | except: | ||||
return None | return None | ||||
def get_device_from_visible(device: Union[str, int]): | |||||
def get_device_from_visible(device: Union[str, int], output_type=int): | |||||
""" | """ | ||||
在有 CUDA_VISIBLE_DEVICES 的情况下,获取对应的设备。 | 在有 CUDA_VISIBLE_DEVICES 的情况下,获取对应的设备。 | ||||
如 CUDA_VISIBLE_DEVICES=2,3 ,device=3 ,则返回1。 | 如 CUDA_VISIBLE_DEVICES=2,3 ,device=3 ,则返回1。 | ||||
:param devices:未转化的设备名 | |||||
:param device: 未转化的设备名 | |||||
:param output_type: 返回值的类型 | |||||
:return: 转化后的设备id | :return: 转化后的设备id | ||||
""" | """ | ||||
if output_type not in [int, str]: | |||||
raise ValueError("Parameter `output_type` should be one of these types: [int, str]") | |||||
if device == "cpu": | if device == "cpu": | ||||
return device | return device | ||||
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") | cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") | ||||
idx = get_paddle_device_id(device) | idx = get_paddle_device_id(device) | ||||
if cuda_visible_devices is None or cuda_visible_devices == "": | if cuda_visible_devices is None or cuda_visible_devices == "": | ||||
# 这个判断一般不会发生,因为 fastnlp 会为 paddle 强行注入 CUDA_VISIBLE_DEVICES | # 这个判断一般不会发生,因为 fastnlp 会为 paddle 强行注入 CUDA_VISIBLE_DEVICES | ||||
return idx | |||||
raise RuntimeError("This situation should not happen, please report us this bug.") | |||||
else: | else: | ||||
# 利用 USER_CUDA_VISIBLDE_DEVICES 获取用户期望的设备 | # 利用 USER_CUDA_VISIBLDE_DEVICES 获取用户期望的设备 | ||||
user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) | user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) | ||||
if user_visible_devices is not None and user_visible_devices != "": | |||||
# 不为空,说明用户设置了 CUDA_VISIBLDE_DEVICES | |||||
idx = user_visible_devices.split(",")[idx] | |||||
else: | |||||
idx = str(idx) | |||||
if user_visible_devices is None: | |||||
raise RuntimeError("This situation cannot happen, please report a bug to us.") | |||||
idx = user_visible_devices.split(",")[idx] | |||||
cuda_visible_devices_list = cuda_visible_devices.split(',') | cuda_visible_devices_list = cuda_visible_devices.split(',') | ||||
assert idx in cuda_visible_devices_list, "Can't find "\ | |||||
"your devices %s in CUDA_VISIBLE_DEVICES[%s]."\ | |||||
% (idx, cuda_visible_devices) | |||||
if idx not in cuda_visible_devices_list: | |||||
raise ValueError(f"Can't find your devices {idx} in CUDA_VISIBLE_DEVICES[{cuda_visible_devices}].") | |||||
res = cuda_visible_devices_list.index(idx) | res = cuda_visible_devices_list.index(idx) | ||||
return res | |||||
if output_type == int: | |||||
return res | |||||
else: | |||||
return f"gpu:{res}" | |||||
def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"): | |||||
# 拿到实例属性; | |||||
def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler"): | |||||
""" | |||||
利用 `batch_sampler` 重新构建一个 DataLoader,起到替换 `batch_sampler` 又不影响原 `dataloader` 的作用。 | |||||
考虑了用户自己定制了 DataLoader 的情形。 | |||||
""" | |||||
# 拿到非下划线开头的实例属性; | |||||
instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')} | instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')} | ||||
# 拿到 dataloader '__init__' 函数的默认函数签名; | |||||
# 拿到 dataloader '__init__' 函数的默认函数签名;可以获取参数名和参数的默认值以及类型 | |||||
init_params = dict(inspect.signature(dataloader.__init__).parameters) | init_params = dict(inspect.signature(dataloader.__init__).parameters) | ||||
# 这里为什么要单独弄的原因在于,用户在定制自己的 dataloader 的同时可能为了方便只设定一些参数,而后面直接使用 **kwargs 的方式,这时如果 | # 这里为什么要单独弄的原因在于,用户在定制自己的 dataloader 的同时可能为了方便只设定一些参数,而后面直接使用 **kwargs 的方式,这时如果 | ||||
# 其在初始化自己的 dataloader 实例的时候加入了一些其它的新的参数(首先这一步是必要的,因为我们只能通过这样加 sampler;另一方面,用户 | # 其在初始化自己的 dataloader 实例的时候加入了一些其它的新的参数(首先这一步是必要的,因为我们只能通过这样加 sampler;另一方面,用户 | ||||
# 可能确实通过 **kwargs 加入了一些新的参数),如果假设用户是这样使用的: "super().__init__(**kwargs)",那么我们就只能去 DataLoader | # 可能确实通过 **kwargs 加入了一些新的参数),如果假设用户是这样使用的: "super().__init__(**kwargs)",那么我们就只能去 DataLoader | ||||
# 中寻找; | |||||
# 中寻找;VAR_KEYWORD 代表 **kwargs | |||||
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | ||||
if has_variadic_kwargs: | if has_variadic_kwargs: | ||||
init_params.update(dict(inspect.signature(DataLoader.__init__).parameters)) | init_params.update(dict(inspect.signature(DataLoader.__init__).parameters)) | ||||
del init_params["self"] | del init_params["self"] | ||||
# 因为我们刚才可能用 DataLoader 的默认参数将用户定制的 dataloader 的参数覆盖掉了,因此需要重新弄一遍; | # 因为我们刚才可能用 DataLoader 的默认参数将用户定制的 dataloader 的参数覆盖掉了,因此需要重新弄一遍; | ||||
# 将同时在实例名和参数名中出现且不是默认值的参数收集起来 | |||||
non_default_params = {name for name, p in init_params.items() if | non_default_params = {name for name, p in init_params.items() if | ||||
name in instance_attrs and p.default != instance_attrs[name]} | name in instance_attrs and p.default != instance_attrs[name]} | ||||
# add `dataset` as it might have been replaced with `*args` | # add `dataset` as it might have been replaced with `*args` | ||||
non_default_params.add("dataset") | non_default_params.add("dataset") | ||||
# 收集不是默认值的参数和它的值 | |||||
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} | reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} | ||||
reconstruct_args.update({"batch_sampler": sampler, "shuffle": False, "drop_last": False, "batch_size": 1}) | |||||
# persistent_workers 在类中的对应成员带有下划线,因此添加进来 | |||||
reconstruct_args.update({ | |||||
"batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1, | |||||
"persistent_workers": dataloader._persistent_workers, | |||||
}) | |||||
# POSITIONAL_OR_KEYWORD 代表一般的参数 | |||||
# 收集初始化函数中出现的、一般形式的、不带默认值且不在 reconstruct_args 中的参数 | |||||
# 也即它们没有在初始化函数和实例成员中同时出现 | |||||
required_args = { | required_args = { | ||||
p.name | p.name | ||||
for p in init_params.values() | for p in init_params.values() | ||||
@@ -323,12 +341,9 @@ def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"): | |||||
required_args = sorted(required_args) | required_args = sorted(required_args) | ||||
dataloader_self_name = dataloader.__class__.__name__ | dataloader_self_name = dataloader.__class__.__name__ | ||||
raise Exception( | raise Exception( | ||||
f"Trying to inject `DistributedBatchSampler` into the `{dataloader_self_name}` instance. " | |||||
f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. " | |||||
"This would fail as some of the `__init__` arguments are not available as instance attributes. " | "This would fail as some of the `__init__` arguments are not available as instance attributes. " | ||||
f"The missing attributes are {required_args}. " | f"The missing attributes are {required_args}. " | ||||
f"HINT: If you wrote the `{dataloader_self_name}` class, define `self.missing_arg_name` or " | |||||
"manually add the `DistributedBatchSampler` as: " | |||||
f"`{dataloader_self_name}(dataset, sampler=DistributedBatchSampler(dataset))`." | |||||
) | ) | ||||
# 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs; | # 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs; | ||||
@@ -340,12 +355,28 @@ def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"): | |||||
missing_kwargs = sorted(missing_kwargs) | missing_kwargs = sorted(missing_kwargs) | ||||
dataloader_self_name = dataloader.__class__.__name__ | dataloader_self_name = dataloader.__class__.__name__ | ||||
raise Exception( | raise Exception( | ||||
f"Trying to inject `DistributedBatchSampler` into the `{dataloader_self_name}` instance. " | |||||
f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. " | |||||
"This would fail as it doesn't expose all its attributes in the `__init__` signature. " | "This would fail as it doesn't expose all its attributes in the `__init__` signature. " | ||||
f"The missing arguments are {missing_kwargs}. " | f"The missing arguments are {missing_kwargs}. " | ||||
f"HINT: If you wrote the `{dataloader_self_name}` class, add the `__init__` arguments or " | |||||
"manually add the `DistributedBatchSampler` as: " | |||||
f"`{dataloader_self_name}(dataset, sampler=DistributedBatchSampler(dataset))`." | |||||
) | ) | ||||
return type(dataloader)(**reconstruct_args) | return type(dataloader)(**reconstruct_args) | ||||
def replace_sampler(dataloader, new_sampler): | |||||
""" | |||||
使用 `new_sampler` 重新构建一个 BatchSampler,并替换到 `dataloader` 中 | |||||
""" | |||||
new_batch_sampler = deepcopy(dataloader.batch_sampler) | |||||
new_batch_sampler.sampler = new_sampler | |||||
return replace_batch_sampler(dataloader, new_batch_sampler) | |||||
def optimizer_state_to_device(state, device): | |||||
new_state = {} | |||||
for name, param in state.items(): | |||||
if isinstance(param, dict): | |||||
new_state[name] = optimizer_state_to_device(param, device) | |||||
elif isinstance(param, paddle.Tensor): | |||||
new_state[name] = paddle_to(param, device).clone() | |||||
else: | |||||
new_state[name] = param | |||||
return new_state |
@@ -530,14 +530,6 @@ class TorchDDPDriver(TorchDriver): | |||||
else: | else: | ||||
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | ||||
def backward(self, loss): | |||||
self.grad_scaler.scale(loss).backward() | |||||
def step(self): | |||||
for optimizer in self.optimizers: | |||||
self.grad_scaler.step(optimizer) | |||||
self.grad_scaler.update() | |||||
def is_global_zero(self): | def is_global_zero(self): | ||||
return self.global_rank == 0 | return self.global_rank == 0 | ||||
@@ -107,14 +107,6 @@ class TorchSingleDriver(TorchDriver): | |||||
else: | else: | ||||
return self._train_step(batch) | return self._train_step(batch) | ||||
def backward(self, loss): | |||||
self.grad_scaler.scale(loss).backward() | |||||
def step(self): | |||||
for optimizer in self.optimizers: | |||||
self.grad_scaler.step(optimizer) | |||||
self.grad_scaler.update() | |||||
def validate_step(self, batch) -> Dict: | def validate_step(self, batch) -> Dict: | ||||
# 因为我们 Tester 的逻辑就是将所有的 metric 传给 tester,然后 tester 控制具体 metric 的 update 和 compute;因此不管用户是否 | # 因为我们 Tester 的逻辑就是将所有的 metric 传给 tester,然后 tester 控制具体 metric 的 update 和 compute;因此不管用户是否 | ||||
# 实现 validate_step 函数,其都应该返回一个字典,具体使用哪些东西则是在 validate_batch_loop 中每一个具体的 metric 自己去拿的; | # 实现 validate_step 函数,其都应该返回一个字典,具体使用哪些东西则是在 validate_batch_loop 中每一个具体的 metric 自己去拿的; | ||||
@@ -72,6 +72,14 @@ class TorchDriver(Driver): | |||||
p.grad.requires_grad_(False) | p.grad.requires_grad_(False) | ||||
p.grad.zero_() | p.grad.zero_() | ||||
def backward(self, loss):
    """
    Run the backward pass for ``loss``.

    The loss is first scaled by ``self.grad_scaler`` so fp16 gradients do
    not underflow under AMP (presumably a no-op scaler when AMP is
    disabled — confirm against the driver's setup code).
    """
    self.grad_scaler.scale(loss).backward()
def step(self):
    """
    Step every optimizer registered on the driver.

    ``grad_scaler.step`` is used instead of calling ``optimizer.step()``
    directly so gradients are unscaled first; ``grad_scaler.update`` then
    adjusts the loss scale for the next iteration.
    """
    for optimizer in self.optimizers:
        self.grad_scaler.step(optimizer)
    self.grad_scaler.update()
@staticmethod | @staticmethod | ||||
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | ||||
if is_train: | if is_train: | ||||
@@ -46,11 +46,14 @@ def get_paddle_device_id(device: Union[str, int]): | |||||
device = device.lower() | device = device.lower() | ||||
if device == "cpu": | if device == "cpu": | ||||
raise ValueError("Cannot get device id from `cpu`.") | raise ValueError("Cannot get device id from `cpu`.") | ||||
elif device == "gpu": | |||||
return 0 | |||||
match_res = re.match(r"gpu:\d+", device) | match_res = re.match(r"gpu:\d+", device) | ||||
if not match_res: | if not match_res: | ||||
raise ValueError( | raise ValueError( | ||||
"The device must be a string which is like 'cpu', 'gpu', 'gpu:x'" | |||||
"The device must be a string which is like 'cpu', 'gpu', 'gpu:x', " | |||||
f"not '{device}'" | |||||
) | ) | ||||
device_id = device.split(':', 1)[1] | device_id = device.split(':', 1)[1] | ||||
device_id = int(device_id) | device_id = int(device_id) | ||||
@@ -185,7 +185,7 @@ def check_user_specific_params(user_params: Dict, fn: Callable): | |||||
return user_params | return user_params | ||||
def dataclass_to_dict(data: "dataclass") -> Dict: | |||||
def dataclass_to_dict(data: "dataclasses.dataclass") -> Dict: | |||||
if not is_dataclass(data): | if not is_dataclass(data): | ||||
raise TypeError(f"Parameter `data` can only be `dataclass` type instead of {type(data)}.") | raise TypeError(f"Parameter `data` can only be `dataclass` type instead of {type(data)}.") | ||||
_dict = dict() | _dict = dict() | ||||
@@ -6,7 +6,8 @@ __all__ = [ | |||||
'is_cur_env_distributed', | 'is_cur_env_distributed', | ||||
'get_global_rank', | 'get_global_rank', | ||||
'rank_zero_call', | 'rank_zero_call', | ||||
'all_rank_call' | |||||
'all_rank_call', | |||||
'get_gpu_count' | |||||
] | ] | ||||
@@ -14,5 +15,5 @@ from .env import * | |||||
from .set_env_on_import import set_env_on_import | from .set_env_on_import import set_env_on_import | ||||
from .set_backend import dump_fastnlp_backend | from .set_backend import dump_fastnlp_backend | ||||
from .imports import * | from .imports import * | ||||
from .utils import _module_available | |||||
from .utils import _module_available, get_gpu_count | |||||
from .distributed import * | from .distributed import * |
@@ -5,13 +5,13 @@ | |||||
import os | import os | ||||
import json | import json | ||||
import sys | import sys | ||||
import subprocess | |||||
from collections import defaultdict | from collections import defaultdict | ||||
from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED | from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED | ||||
from fastNLP.envs.imports import SUPPORT_BACKENDS | from fastNLP.envs.imports import SUPPORT_BACKENDS | ||||
from fastNLP.envs.utils import _module_available | |||||
from fastNLP.envs.utils import _module_available, get_gpu_count | |||||
def _set_backend(): | def _set_backend(): | ||||
""" | """ | ||||
@@ -56,17 +56,18 @@ def _set_backend(): | |||||
if 'PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ: | if 'PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ: | ||||
# 在分布式子进程下,根据 USER_VISIBLE_DEVICES 得到进程真正占有的设备 | # 在分布式子进程下,根据 USER_VISIBLE_DEVICES 得到进程真正占有的设备 | ||||
selected_gpus = os.environ['FLAGS_selected_gpus'].split(',') | selected_gpus = os.environ['FLAGS_selected_gpus'].split(',') | ||||
if user_visible_devices is not None and user_visible_devices != "": | |||||
if user_visible_devices is not None: | |||||
# 用户通过 CUDA_VISIBLE_DEVICES 启动了分布式训练 | # 用户通过 CUDA_VISIBLE_DEVICES 启动了分布式训练 | ||||
# 此时经过 set_backend,用户的设置会保存在 USER_CUDA_VISIBLE_DEVICES 中 | # 此时经过 set_backend,用户的设置会保存在 USER_CUDA_VISIBLE_DEVICES 中 | ||||
# 我们需要从中找到真正使用的设备编号 | # 我们需要从中找到真正使用的设备编号 | ||||
user_visible_devices = user_visible_devices.split(",") | user_visible_devices = user_visible_devices.split(",") | ||||
selected_gpus = ",".join([user_visible_devices[int(i)] for i in selected_gpus]) | selected_gpus = ",".join([user_visible_devices[int(i)] for i in selected_gpus]) | ||||
else: | else: | ||||
# 设置 USER_CUDA_VISIBLE_DEVICES 表明用户视角中所有设备可见 | |||||
os.environ[USER_CUDA_VISIBLE_DEVICES] = "" | |||||
# TODO 这里的 [0] 可能在单个节点多卡的时候有问题 | |||||
os.environ['CUDA_VISIBLE_DEVICES'] = selected_gpus[0] | |||||
# 没有找到 USER_CUDA_VISIBLE_DEVICES,则将之设置为所有的设备 | |||||
os.environ[USER_CUDA_VISIBLE_DEVICES] = ",".join(map(str, list( | |||||
range(get_gpu_count()) | |||||
))) | |||||
os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(selected_gpus) | |||||
os.environ['FLAGS_selected_gpus'] = ",".join([str(g) for g in range(len(selected_gpus))]) | os.environ['FLAGS_selected_gpus'] = ",".join([str(g) for g in range(len(selected_gpus))]) | ||||
os.environ['FLAGS_selected_accelerators'] = ",".join([str(g) for g in range(len(selected_gpus))]) | os.environ['FLAGS_selected_accelerators'] = ",".join([str(g) for g in range(len(selected_gpus))]) | ||||
elif 'CUDA_VISIBLE_DEVICES' in os.environ: | elif 'CUDA_VISIBLE_DEVICES' in os.environ: | ||||
@@ -78,7 +79,9 @@ def _set_backend(): | |||||
else: | else: | ||||
# 没有设置的话限制在单卡上,防止多进程时占用别的卡 | # 没有设置的话限制在单卡上,防止多进程时占用别的卡 | ||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0' | os.environ['CUDA_VISIBLE_DEVICES'] = '0' | ||||
os.environ[USER_CUDA_VISIBLE_DEVICES] = '' | |||||
os.environ[USER_CUDA_VISIBLE_DEVICES] = ",".join(map(str, list( | |||||
range(get_gpu_count()) | |||||
))) | |||||
elif backend == 'jittor': | elif backend == 'jittor': | ||||
assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | ||||
@@ -36,8 +36,7 @@ def set_env_on_import_torch(): | |||||
# TODO paddle may need set this | # TODO paddle may need set this | ||||
def set_env_on_import_paddle(): | def set_env_on_import_paddle(): | ||||
# todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_LAUNCH_PROCESS | |||||
if "PADDLE_TRANERS_NUM" in os.environ and "PADDLE_TRAINER_ID" in os.environ \ | |||||
if "PADDLE_TRAINERS_NUM" in os.environ and "PADDLE_TRAINER_ID" in os.environ \ | |||||
and "PADDLE_RANK_IN_NODE" in os.environ: | and "PADDLE_RANK_IN_NODE" in os.environ: | ||||
# 检测到了分布式环境的环境变量 | # 检测到了分布式环境的环境变量 | ||||
os.environ[FASTNLP_GLOBAL_RANK] = os.environ["PADDLE_TRAINER_ID"] | os.environ[FASTNLP_GLOBAL_RANK] = os.environ["PADDLE_TRAINER_ID"] | ||||
@@ -3,6 +3,7 @@ from typing import Callable | |||||
import importlib | import importlib | ||||
from pkg_resources import DistributionNotFound | from pkg_resources import DistributionNotFound | ||||
from packaging.version import Version | from packaging.version import Version | ||||
import subprocess | |||||
import pkg_resources | import pkg_resources | ||||
@@ -46,3 +47,15 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version: | |||||
if use_base_version: | if use_base_version: | ||||
pkg_version = Version(pkg_version.base_version) | pkg_version = Version(pkg_version.base_version) | ||||
return op(pkg_version, Version(version)) | return op(pkg_version, Version(version)) | ||||
def get_gpu_count():
    """
    Count the visible GPUs by querying ``nvidia-smi`` on the command line.

    :return: the number of GPUs, or -1 when no GPU device is available
        (missing ``nvidia-smi`` binary or a failing query).
    """
    try:
        lines = subprocess.check_output(['nvidia-smi', '--query-gpu=memory.used', '--format=csv'])
        # The csv output has one header line plus a trailing newline, hence -2.
        return len(lines.split(b"\n")) - 2
    except (OSError, subprocess.CalledProcessError):
        # OSError covers a missing nvidia-smi binary; CalledProcessError a
        # failing query (e.g. no driver loaded).  The previous bare `except:`
        # also swallowed KeyboardInterrupt/SystemExit and real bugs.
        return -1
@@ -0,0 +1,93 @@ | |||||
""" | |||||
这个文件测试用户以python -m paddle.distributed.launch 启动的情况 | |||||
看看有没有用pytest执行的机会 | |||||
python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py | |||||
""" | |||||
import os | |||||
os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
import sys | |||||
sys.path.append("../../../") | |||||
from dataclasses import dataclass | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.core.metrics.accuracy import Accuracy | |||||
from fastNLP.core.callbacks.progress_callback import RichCallback | |||||
from fastNLP.core.callbacks import Callback | |||||
import paddle | |||||
from paddle.optimizer import Adam | |||||
from paddle.io import DataLoader | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||||
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | |||||
from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback | |||||
@dataclass
class MNISTTrainFleetConfig:
    # Hyper-parameters for the fleet trainer smoke test.
    num_labels: int = 10         # number of output classes
    feature_dimension: int = 10  # input feature size
    batch_size: int = 32         # dataloader batch size
    shuffle: bool = True         # whether to shuffle the training data
    # NOTE(review): no annotation, so this is a plain class attribute and NOT
    # a dataclass field; it is read as MNISTTrainFleetConfig.validate_every.
    validate_every = -1
def test_trainer_fleet(
    driver,
    device,
    callbacks,
    n_epochs,
):
    """
    Smoke-test a fleet (multi-GPU) Trainer run end to end.

    :param driver: driver name passed to ``Trainer`` (e.g. ``"fleet"``).
    :param device: device spec, e.g. a list of GPU ids.
    :param callbacks: list of fastNLP callbacks to attach.
    :param n_epochs: number of training epochs.
    """
    model = PaddleNormalModel_Classification_1(
        num_labels=MNISTTrainFleetConfig.num_labels,
        feature_dimension=MNISTTrainFleetConfig.feature_dimension
    )
    optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)

    train_dataloader = DataLoader(
        dataset=PaddleRandomMaxDataset(6400, MNISTTrainFleetConfig.feature_dimension),
        batch_size=MNISTTrainFleetConfig.batch_size,
        shuffle=True
    )
    val_dataloader = DataLoader(
        dataset=PaddleRandomMaxDataset(1280, MNISTTrainFleetConfig.feature_dimension),
        batch_size=MNISTTrainFleetConfig.batch_size,
        shuffle=True
    )
    # (The original code re-assigned the dataloaders to themselves via
    # intermediate aliases; those no-op assignments have been removed.)
    metrics = {"acc": Accuracy()}

    trainer = Trainer(
        model=model,
        driver=driver,
        device=device,
        optimizers=optimizers,
        train_dataloader=train_dataloader,
        validate_dataloaders=val_dataloader,
        validate_every=MNISTTrainFleetConfig.validate_every,
        input_mapping=None,
        output_mapping=None,
        metrics=metrics,
        n_epochs=n_epochs,
        callbacks=callbacks,
        output_from_new_proc="logs",
    )
    trainer.run()
# Entry point: meant to be launched with
#   python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py
if __name__ == "__main__":
    driver = "fleet"
    device = [0,2,3]
    # driver = "paddle"
    # device = 2
    callbacks = [
        # RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),
        RichCallback(5),  # rich progress bar, refreshed every 5 steps
    ]
    test_trainer_fleet(
        driver=driver,
        device=device,
        callbacks=callbacks,
        n_epochs=5,
    )
@@ -0,0 +1,98 @@ | |||||
""" | |||||
这个文件测试用户以python -m paddle.distributed.launch 启动的情况 | |||||
并且自己初始化了 fleet | |||||
python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py | |||||
""" | |||||
import os | |||||
os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
import sys | |||||
sys.path.append("../../../") | |||||
from dataclasses import dataclass | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.core.metrics.accuracy import Accuracy | |||||
from fastNLP.core.callbacks.progress_callback import RichCallback | |||||
from fastNLP.core.callbacks import Callback | |||||
import paddle | |||||
from paddle.optimizer import Adam | |||||
from paddle.io import DataLoader | |||||
import paddle.distributed.fleet as fleet | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_2 | |||||
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | |||||
from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback | |||||
@dataclass
class MNISTTrainFleetConfig:
    # Hyper-parameters for the user-initialized-fleet trainer smoke test.
    num_labels: int = 10         # number of output classes
    feature_dimension: int = 10  # input feature size
    batch_size: int = 32         # dataloader batch size
    shuffle: bool = True         # whether to shuffle the training data
    # NOTE(review): no annotation, so this is a plain class attribute and NOT
    # a dataclass field; it is read as MNISTTrainFleetConfig.validate_every.
    validate_every = -1
def test_trainer_fleet(
    driver,
    device,
    callbacks,
    n_epochs,
):
    """
    Smoke-test a Trainer run where the user initializes fleet themselves.

    ``fleet.init`` is called up front and the model/optimizer are wrapped
    with ``fleet.distributed_model`` / ``fleet.distributed_optimizer`` before
    being handed to the Trainer.

    :param driver: driver name passed to ``Trainer`` (e.g. ``"fleet"``).
    :param device: device spec, e.g. a list of GPU ids.
    :param callbacks: list of fastNLP callbacks to attach.
    :param n_epochs: number of training epochs.
    """
    fleet.init(is_collective=True)

    model = PaddleNormalModel_Classification_2(
        num_labels=MNISTTrainFleetConfig.num_labels,
        feature_dimension=MNISTTrainFleetConfig.feature_dimension,
    )
    optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)

    model = fleet.distributed_model(model)
    optimizers = fleet.distributed_optimizer(optimizers)

    train_dataloader = DataLoader(
        dataset=PaddleRandomMaxDataset(6400, MNISTTrainFleetConfig.feature_dimension),
        batch_size=MNISTTrainFleetConfig.batch_size,
        shuffle=True
    )
    val_dataloader = DataLoader(
        dataset=PaddleRandomMaxDataset(1280, MNISTTrainFleetConfig.feature_dimension),
        batch_size=MNISTTrainFleetConfig.batch_size,
        shuffle=True
    )
    # (The original code re-assigned the dataloaders to themselves via
    # intermediate aliases; those no-op assignments have been removed.)
    metrics = {"acc": Accuracy()}

    trainer = Trainer(
        model=model,
        driver=driver,
        device=device,
        optimizers=optimizers,
        train_dataloader=train_dataloader,
        validate_dataloaders=val_dataloader,
        validate_every=MNISTTrainFleetConfig.validate_every,
        input_mapping=None,
        output_mapping=None,
        metrics=metrics,
        n_epochs=n_epochs,
        callbacks=callbacks,
        output_from_new_proc="logs",
        # NOTE(review): CUDA_VISIBLE_DEVICES may contain a comma-separated
        # list (e.g. "0,2,3"), which would make this an invalid device string
        # like "gpu:0,2,3" — confirm the launcher sets a single id here.
        data_device=f"gpu:{os.environ['CUDA_VISIBLE_DEVICES']}"
    )
    trainer.run()
# Entry point: meant to be launched with
#   python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py
if __name__ == "__main__":
    driver = "fleet"
    device = [0,2,3]
    callbacks = [
        # RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),
        RichCallback(5),  # rich progress bar, refreshed every 5 steps
    ]
    test_trainer_fleet(
        driver=driver,
        device=device,
        callbacks=callbacks,
        n_epochs=30,
    )
@@ -143,7 +143,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
accumulation_steps, | accumulation_steps, | ||||
n_epochs=6, | n_epochs=6, | ||||
): | ): | ||||
callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)] | |||||
callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.1, larger_better=True)] | |||||
trainer = Trainer( | trainer = Trainer( | ||||
model=model_and_optimizers.model, | model=model_and_optimizers.model, | ||||
driver=driver, | driver=driver, | ||||
@@ -1,83 +1,103 @@ | |||||
import os | |||||
import pytest | import pytest | ||||
from fastNLP.envs.set_backend import set_env | |||||
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle | |||||
set_env_on_import_paddle() | |||||
set_env("paddle") | |||||
import paddle | |||||
os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
from fastNLP.core.drivers import PaddleSingleDriver, PaddleFleetDriver | |||||
from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver | from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver | ||||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | |||||
from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | |||||
from fastNLP.envs import get_gpu_count | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
import paddle | |||||
def test_incorrect_driver(): | def test_incorrect_driver(): | ||||
model = PaddleNormalModel_Classification_1(2, 100) | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
driver = initialize_paddle_driver("torch") | |||||
driver = initialize_paddle_driver("torch", 0, model) | |||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
["cpu", "gpu:0", [1, 2, 3], 0, "gpu:1"] | |||||
["cpu", "gpu:0", 0, [1]] | |||||
) | ) | ||||
def test_get_single_device(device): | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["paddle"] | |||||
) | |||||
def test_get_single_device(driver, device): | |||||
""" | """ | ||||
测试正常情况下初始化PaddleSingleDriver的情况 | 测试正常情况下初始化PaddleSingleDriver的情况 | ||||
""" | """ | ||||
model = PaddleNormalModel_Classification(2, 100) | |||||
driver = initialize_paddle_driver("paddle", device, model) | |||||
model = PaddleNormalModel_Classification_1(2, 100) | |||||
driver = initialize_paddle_driver(driver, device, model) | |||||
assert isinstance(driver, PaddleSingleDriver) | assert isinstance(driver, PaddleSingleDriver) | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
["cpu", "gpu:0", [1, 2, 3], 0, "gpu:1"] | |||||
[0, 1] | |||||
) | ) | ||||
def test_get_single_device_with_visiblde_devices(device): | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["fleet"] | |||||
) | |||||
@magic_argv_env_context | |||||
def test_get_fleet_2(driver, device): | |||||
""" | """ | ||||
测试 CUDA_VISIBLE_DEVICES 启动时初始化PaddleSingleDriver的情况 | |||||
测试 fleet 多卡的初始化情况 | |||||
""" | """ | ||||
# TODO | |||||
model = PaddleNormalModel_Classification(2, 100) | |||||
driver = initialize_paddle_driver("paddle", device, model) | |||||
model = PaddleNormalModel_Classification_1(64, 10) | |||||
driver = initialize_paddle_driver(driver, device, model) | |||||
assert isinstance(driver, PaddleSingleDriver) | |||||
assert isinstance(driver, PaddleFleetDriver) | |||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
[[1, 2, 3]] | |||||
[[0, 2, 3], -1] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["paddle", "fleet"] | |||||
) | ) | ||||
def test_get_fleet(device): | |||||
@magic_argv_env_context | |||||
def test_get_fleet(driver, device): | |||||
""" | """ | ||||
测试 fleet 多卡的初始化情况 | 测试 fleet 多卡的初始化情况 | ||||
""" | """ | ||||
model = PaddleNormalModel_Classification(2, 100) | |||||
driver = initialize_paddle_driver("paddle", device, model) | |||||
model = PaddleNormalModel_Classification_1(64, 10) | |||||
driver = initialize_paddle_driver(driver, device, model) | |||||
assert isinstance(driver, PaddleFleetDriver) | assert isinstance(driver, PaddleFleetDriver) | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | |||||
[[1,2,3]] | |||||
("driver", "device"), | |||||
[("fleet", "cpu")] | |||||
) | ) | ||||
def test_get_fleet(device): | |||||
@magic_argv_env_context | |||||
def test_get_fleet_cpu(driver, device): | |||||
""" | """ | ||||
测试 launch 启动 fleet 多卡的初始化情况 | |||||
测试试图在 cpu 上初始化分布式训练的情况 | |||||
""" | """ | ||||
# TODO | |||||
model = PaddleNormalModel_Classification(2, 100) | |||||
driver = initialize_paddle_driver("paddle", device, model) | |||||
assert isinstance(driver, PaddleFleetDriver) | |||||
model = PaddleNormalModel_Classification_1(64, 10) | |||||
with pytest.raises(ValueError): | |||||
driver = initialize_paddle_driver(driver, device, model) | |||||
def test_device_out_of_range(device): | |||||
@pytest.mark.parametrize( | |||||
"device", | |||||
[-2, [0, get_gpu_count() + 1, 3], [-2], get_gpu_count() + 1] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["paddle", "fleet"] | |||||
) | |||||
@magic_argv_env_context | |||||
def test_device_out_of_range(driver, device): | |||||
""" | """ | ||||
测试传入的device超过范围的情况 | 测试传入的device超过范围的情况 | ||||
""" | """ | ||||
pass | |||||
model = PaddleNormalModel_Classification_1(2, 100) | |||||
with pytest.raises(ValueError): | |||||
driver = initialize_paddle_driver(driver, device, model) |
@@ -1,262 +0,0 @@ | |||||
import unittest | |||||
import torch | |||||
from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver | |||||
import paddle | |||||
from paddle.io import Dataset, DataLoader | |||||
class Net(paddle.nn.Layer):
    """A small 4-layer MLP (784 -> 64 -> 32 -> 10 -> 10) with no activations."""

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = paddle.nn.Linear(784, 64)
        self.fc2 = paddle.nn.Linear(64, 32)
        self.fc3 = paddle.nn.Linear(32, 10)
        self.fc4 = paddle.nn.Linear(10, 10)

    def forward(self, x):
        # Feed the input through the four linear layers in order.
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = layer(x)
        return x
class PaddleDataset(Dataset):
    """Map-style dataset of 320 random paddle tensors of shape (3, 4)."""

    def __init__(self):
        super(PaddleDataset, self).__init__()
        # Materialize all samples up front; values are random at construction
        # time and then fixed for the dataset's lifetime.
        self.items = [paddle.rand((3, 4)) for i in range(320)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]
class TorchNet(torch.nn.Module):
    """
    Dummy torch module used to exercise framework type checks: it bundles a
    linear layer, a softmax, a conv layer, a plain tensor and a parameter.
    It deliberately defines no forward pass.
    """

    def __init__(self):
        super(TorchNet, self).__init__()
        # The assignments below are independent; grouped by kind for clarity.
        self.torch_tensor = torch.ones(3, 3)
        self.torch_param = torch.nn.Parameter(torch.ones(4, 4))
        self.torch_fc1 = torch.nn.Linear(10, 10)
        self.torch_softmax = torch.nn.Softmax(0)
        self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3)
class TorchDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding 320 all-ones torch tensors of shape (3, 4)."""

    def __init__(self):
        super(TorchDataset, self).__init__()
        # Pre-build every sample; all items are identical ones-tensors.
        self.items = [torch.ones(3, 4) for _ in range(320)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]
class PaddleDriverTestCase(unittest.TestCase): | |||||
""" | |||||
PaddleDriver的测试类,由于类的特殊性仅测试部分函数,其它的由PaddleSingleDriver和PaddleFleetDriver完成测试 | |||||
""" | |||||
def setUp(self): | |||||
model = Net() | |||||
self.driver = PaddleDriver(model) | |||||
def test_check_single_optimizer_legacy(self): | |||||
""" | |||||
测试传入单个optimizer时的表现 | |||||
""" | |||||
optimizer = paddle.optimizer.Adam( | |||||
parameters=self.driver.model.parameters(), | |||||
learning_rate=0.01 | |||||
) | |||||
self.driver.set_optimizers(optimizer) | |||||
optimizer = torch.optim.Adam(TorchNet().parameters(), 0.01) | |||||
# 传入torch的optimizer时,应该报错ValueError | |||||
with self.assertRaises(ValueError) as cm: | |||||
self.driver.set_optimizers(optimizer) | |||||
def test_check_optimizers_legacy(self): | |||||
""" | |||||
测试传入optimizer list的表现 | |||||
""" | |||||
optimizers = [ | |||||
paddle.optimizer.Adam( | |||||
parameters=self.driver.model.parameters(), | |||||
learning_rate=0.01 | |||||
) for i in range(10) | |||||
] | |||||
self.driver.set_optimizers(optimizers) | |||||
optimizers += [ | |||||
torch.optim.Adam(TorchNet().parameters(), 0.01) | |||||
] | |||||
with self.assertRaises(ValueError) as cm: | |||||
self.driver.set_optimizers(optimizers) | |||||
def test_check_dataloader_legacy_in_train(self): | |||||
""" | |||||
测试is_train参数为True时,_check_dataloader_legality函数的表现 | |||||
""" | |||||
dataloader = paddle.io.DataLoader(PaddleDataset()) | |||||
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||||
# 创建torch的dataloader | |||||
dataloader = torch.utils.data.DataLoader( | |||||
TorchDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
with self.assertRaises(ValueError) as cm: | |||||
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||||
def test_check_dataloader_legacy_in_test(self): | |||||
""" | |||||
测试is_train参数为False时,_check_dataloader_legality函数的表现 | |||||
""" | |||||
# 此时传入的应该是dict | |||||
dataloader = {"train": paddle.io.DataLoader(PaddleDataset()), "test":paddle.io.DataLoader(PaddleDataset())} | |||||
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||||
# 传入的不是dict,应该报错 | |||||
dataloader = paddle.io.DataLoader(PaddleDataset()) | |||||
with self.assertRaises(ValueError) as cm: | |||||
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||||
# 创建torch的dataloader | |||||
train_loader = torch.utils.data.DataLoader( | |||||
TorchDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
test_loader = torch.utils.data.DataLoader( | |||||
TorchDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
dataloader = {"train": train_loader, "test": test_loader} | |||||
with self.assertRaises(ValueError) as cm: | |||||
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||||
def test_tensor_to_numeric(self): | |||||
""" | |||||
测试tensor_to_numeric函数 | |||||
""" | |||||
# 单个张量 | |||||
tensor = paddle.to_tensor(3) | |||||
res = PaddleDriver.tensor_to_numeric(tensor) | |||||
self.assertEqual(res, 3) | |||||
tensor = paddle.rand((3, 4)) | |||||
res = PaddleDriver.tensor_to_numeric(tensor) | |||||
self.assertListEqual(res, tensor.tolist()) | |||||
# 张量list | |||||
tensor_list = [paddle.rand((6, 4, 2)) for i in range(10)] | |||||
res = PaddleDriver.tensor_to_numeric(tensor_list) | |||||
self.assertTrue(res, list) | |||||
tensor_list = [t.tolist() for t in tensor_list] | |||||
self.assertListEqual(res, tensor_list) | |||||
# 张量tuple | |||||
tensor_tuple = tuple([paddle.rand((6, 4, 2)) for i in range(10)]) | |||||
res = PaddleDriver.tensor_to_numeric(tensor_tuple) | |||||
self.assertTrue(res, tuple) | |||||
tensor_tuple = tuple([t.tolist() for t in tensor_tuple]) | |||||
self.assertTupleEqual(res, tensor_tuple) | |||||
# 张量dict | |||||
tensor_dict = { | |||||
"tensor": paddle.rand((3, 4)), | |||||
"list": [paddle.rand((6, 4, 2)) for i in range(10)], | |||||
"dict":{ | |||||
"list": [paddle.rand((6, 4, 2)) for i in range(10)], | |||||
"tensor": paddle.rand((3, 4)) | |||||
}, | |||||
"int": 2, | |||||
"string": "test string" | |||||
} | |||||
res = PaddleDriver.tensor_to_numeric(tensor_dict) | |||||
self.assertIsInstance(res, dict) | |||||
self.assertListEqual(res["tensor"], tensor_dict["tensor"].tolist()) | |||||
self.assertIsInstance(res["list"], list) | |||||
for r, d in zip(res["list"], tensor_dict["list"]): | |||||
self.assertListEqual(r, d.tolist()) | |||||
self.assertIsInstance(res["int"], int) | |||||
self.assertIsInstance(res["string"], str) | |||||
self.assertIsInstance(res["dict"], dict) | |||||
self.assertIsInstance(res["dict"]["list"], list) | |||||
for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]): | |||||
self.assertListEqual(r, d.tolist()) | |||||
self.assertListEqual(res["dict"]["tensor"], tensor_dict["dict"]["tensor"].tolist()) | |||||
def test_set_model_mode(self): | |||||
""" | |||||
测试set_model_mode函数 | |||||
""" | |||||
self.driver.set_model_mode("train") | |||||
self.assertTrue(self.driver.model.training) | |||||
self.driver.set_model_mode("eval") | |||||
self.assertFalse(self.driver.model.training) | |||||
# 应该报错 | |||||
with self.assertRaises(AssertionError) as cm: | |||||
self.driver.set_model_mode("test") | |||||
def test_move_model_to_device_cpu(self): | |||||
""" | |||||
测试move_model_to_device函数 | |||||
""" | |||||
PaddleDriver.move_model_to_device(self.driver.model, "cpu") | |||||
self.assertTrue(self.driver.model.fc1.weight.place.is_cpu_place()) | |||||
def test_move_model_to_device_gpu(self): | |||||
""" | |||||
测试move_model_to_device函数 | |||||
""" | |||||
PaddleDriver.move_model_to_device(self.driver.model, "gpu:0") | |||||
self.assertTrue(self.driver.model.fc1.weight.place.is_gpu_place()) | |||||
self.assertEqual(self.driver.model.fc1.weight.place.gpu_device_id(), 0) | |||||
def test_worker_init_function(self): | |||||
""" | |||||
测试worker_init_function | |||||
""" | |||||
# 先确保不影响运行 | |||||
# TODO:正确性 | |||||
PaddleDriver.worker_init_function(0) | |||||
def test_set_deterministic_dataloader(self): | |||||
""" | |||||
测试set_deterministic_dataloader | |||||
""" | |||||
# 先确保不影响运行 | |||||
# TODO:正确性 | |||||
dataloader = DataLoader(PaddleDataset()) | |||||
self.driver.set_deterministic_dataloader(dataloader) | |||||
def test_set_sampler_epoch(self): | |||||
""" | |||||
测试set_sampler_epoch | |||||
""" | |||||
# 先确保不影响运行 | |||||
# TODO:正确性 | |||||
dataloader = DataLoader(PaddleDataset()) | |||||
self.driver.set_sampler_epoch(dataloader, 0) | |||||
def test_get_dataloader_args(self): | |||||
""" | |||||
测试get_dataloader_args | |||||
""" | |||||
# 先确保不影响运行 | |||||
# TODO:正确性 | |||||
dataloader = DataLoader(PaddleDataset()) | |||||
res = PaddleDriver.get_dataloader_args(dataloader) |
@@ -1,19 +1,19 @@ | |||||
import os | |||||
os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
import pytest | import pytest | ||||
from pathlib import Path | |||||
from fastNLP.envs.set_backend import set_env | |||||
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle | |||||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | |||||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from fastNLP.core import synchronize_safe_rm | |||||
set_env_on_import_paddle() | |||||
set_env("paddle") | |||||
import paddle | import paddle | ||||
from paddle.io import DataLoader, BatchSampler | from paddle.io import DataLoader, BatchSampler | ||||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | |||||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler | |||||
from fastNLP.core.samplers import RandomBatchSampler | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | |||||
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST, PaddleRandomDataset | |||||
from fastNLP.core import synchronize_safe_rm | |||||
import torch | |||||
############################################################################ | ############################################################################ | ||||
@@ -26,38 +26,116 @@ def generate_random_driver(features, labels): | |||||
""" | """ | ||||
生成driver | 生成driver | ||||
""" | """ | ||||
model = PaddleNormalModel_Classification(labels, features) | |||||
model = PaddleNormalModel_Classification_1(labels, features) | |||||
opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01) | opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01) | ||||
driver = PaddleSingleDriver(model) | |||||
driver = PaddleSingleDriver(model, device="cpu") | |||||
driver.set_optimizers(opt) | driver.set_optimizers(opt) | ||||
driver.setup() | |||||
return driver | return driver | ||||
@pytest.fixture
def prepare_test_save_load():
    """
    Pytest fixture: two independently initialised drivers plus one dataloader
    over the same random dataset, used by the save/load round-trip tests.
    """
    dataset = PaddleRandomMaxDataset(320, 10)
    dataloader = DataLoader(dataset, batch_size=32)
    driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
    return driver1, driver2, dataloader
@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_with_randombatchsampler(only_state_dict):
    """
    Test Driver.save / Driver.load, mainly for the case where the dataloader's
    batch_sampler has been replaced by a reproducible ``RandomBatchSampler``.
    """
    try:
        path = "model.ckp"

        driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
        dataset = PaddleRandomMaxDataset(80, 10)
        dataloader = DataLoader(
            dataset=dataset,
            batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
        )

        # TODO: iterate a few batches here once checkpoint-resume is complete
        sampler_states = dataloader.batch_sampler.state_dict()

        if only_state_dict:
            driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True)
        else:
            driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
        states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)

        # 1. check the optimizer state
        # TODO: the optimizer state_dict is currently always empty

        # 2. check that the batch_sampler was loaded and substituted correctly
        replaced_loader = states["dataloader"]
        assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
        assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
        assert replaced_loader.batch_sampler.data_idx == sampler_states["data_idx"]

        # 3. check that the model parameters were loaded correctly
        for batch in dataloader:
            res1 = driver1.validate_step(batch)
            res2 = driver2.validate_step(batch)
            assert paddle.equal_all(res1["pred"], res2["pred"])

        # 4. check batch_idx
        # TODO
    finally:
        synchronize_safe_rm(path)
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
def test_save_and_load_with_randomsampler(only_state_dict): | |||||
""" | |||||
测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 | |||||
""" | |||||
try: | |||||
path = "model.ckp" | |||||
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | |||||
dataset = PaddleRandomMaxDataset(80, 10) | |||||
batch_sampler = BatchSampler(dataset=dataset, batch_size=2) | |||||
batch_sampler.sampler = RandomSampler(dataset, True) | |||||
dataloader = DataLoader( | |||||
dataset, | |||||
batch_sampler=batch_sampler | |||||
) | |||||
# TODO 断点重训完善后在这里迭代几次 | |||||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | |||||
if only_state_dict: | |||||
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True) | |||||
else: | |||||
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
# 1. 检查 optimizer 的状态 | |||||
# TODO optimizer 的 state_dict 总是为空 | |||||
# 2. 检查 sampler 是否被正确地加载和替换 | |||||
replaced_loader = states["dataloader"] | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] | |||||
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] | |||||
assert replaced_loader.batch_sampler.sampler.num_consumed_samples == sampler_states["num_consumed_samples"] | |||||
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | |||||
# 3. 检查 model 的参数是否被正确加载 | |||||
for batch in dataloader: | for batch in dataloader: | ||||
res1 = driver1.validate_step(batch) | res1 = driver1.validate_step(batch) | ||||
res2 = driver2.validate_step(batch) | res2 = driver2.validate_step(batch) | ||||
assert paddle.equal_all(res1["pred"], res2["pred"]) | assert paddle.equal_all(res1["pred"], res2["pred"]) | ||||
# 4. 检查 batch_idx | |||||
# TODO | |||||
finally: | finally: | ||||
synchronize_safe_rm(path) | synchronize_safe_rm(path) | ||||
def test_save_and_load_state_dict(prepare_test_save_load):
    """
    Test save_model / load_model with only the state dict.
    TODO: the optimizer state_dict is currently empty, so it is not tested.
    """
    try:
        path = "dict"
        driver1, driver2, dataloader = prepare_test_save_load

        driver1.save_model(path)
        driver2.load_model(path)

        for batch in dataloader:
            batch = driver1.move_data_to_device(batch)
            res1 = driver1.validate_step(batch)
            res2 = driver2.validate_step(batch)
            # NOTE(review): the tail of this function lies outside the visible
            # hunk; assumed to mirror the sibling tests — confirm against the
            # full file.
            assert paddle.equal_all(res1["pred"], res2["pred"])
    finally:
        synchronize_safe_rm(path)
def test_save_and_load_whole_model(prepare_test_save_load):
    """
    Test save_model / load_model with the whole exported model
    (``only_state_dict=False``).
    TODO: the optimizer state_dict is currently empty, so it is not tested.
    """
    try:
        path = "model"
        driver1, driver2, dataloader = prepare_test_save_load

        driver1.save_model(path, only_state_dict=False, input_spec=[paddle.ones((32, 10))])
        driver2.load_model(path, only_state_dict=False)

        for batch in dataloader:
            batch = driver1.move_data_to_device(batch)
            res1 = driver1.validate_step(batch)
            res2 = driver2.validate_step(batch)
            assert paddle.equal_all(res1["pred"], res2["pred"])
    finally:
        # Paddle's whole-model export produces three files with these suffixes.
        synchronize_safe_rm(path + ".pdiparams")
        synchronize_safe_rm(path + ".pdiparams.info")
        synchronize_safe_rm(path + ".pdmodel")
class TestSingleDeviceFunction: | class TestSingleDeviceFunction: | ||||
@@ -109,8 +191,8 @@ class TestSingleDeviceFunction: | |||||
@classmethod
def setup_class(cls):
    # One shared CPU driver for the whole test class.
    model = PaddleNormalModel_Classification_1(10, 784)
    cls.driver = PaddleSingleDriver(model, device="cpu")
def test_unwrap_model(self): | def test_unwrap_model(self): | ||||
""" | """ | ||||
@@ -125,22 +207,6 @@ class TestSingleDeviceFunction: | |||||
self.driver.check_evaluator_mode("validate") | self.driver.check_evaluator_mode("validate") | ||||
self.driver.check_evaluator_mode("test") | self.driver.check_evaluator_mode("test") | ||||
def test_get_model_device_cpu(self): | |||||
""" | |||||
测试get_model_device | |||||
""" | |||||
self.driver = PaddleSingleDriver(PaddleNormalModel_Classification(10, 784), "cpu") | |||||
device = self.driver.get_model_device() | |||||
assert device == "cpu", device | |||||
def test_get_model_device_gpu(self): | |||||
""" | |||||
测试get_model_device | |||||
""" | |||||
self.driver = PaddleSingleDriver(PaddleNormalModel_Classification(10, 784), "gpu:0") | |||||
device = self.driver.get_model_device() | |||||
assert device == "gpu:0", device | |||||
def test_is_distributed(self):
    """A single-device driver must report that it is not distributed."""
    assert self.driver.is_distributed() == False
@@ -151,18 +217,420 @@ class TestSingleDeviceFunction: | |||||
""" | """ | ||||
self.driver.move_data_to_device(paddle.rand((32, 64))) | self.driver.move_data_to_device(paddle.rand((32, 64))) | ||||
@pytest.mark.parametrize( | |||||
"dist_sampler", | |||||
["dist", RandomBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"reproducible", | |||||
[True, False] | |||||
) | |||||
def test_repalce_sampler(self, dist_sampler, reproducible): | |||||
class TestSetDistReproDataloder:
    """
    Tests dedicated to the ``set_dist_repro_dataloader`` function.
    """
    def setup_method(self):
        self.dataset = PaddleNormalDataset(20)
        self.dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
        model = PaddleNormalModel_Classification_1(10, 32)
        self.driver = PaddleSingleDriver(model, device="cpu")

    def test_set_dist_repro_dataloader_with_reproducible_false(self):
        """
        Behaviour when ``reproducible`` is False: with a string ``dist`` the
        original dataloader must be returned unchanged.
        """
        replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=False)

        assert replaced_loader is self.dataloader

    def test_set_dist_repro_dataloader_with_reproducible_true(self):
        """
        Behaviour when ``reproducible`` is True: with a string ``dist`` a new
        dataloader must be returned whose batch_sampler is a RandomBatchSampler.
        """
        replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=True)

        assert not (replaced_loader is self.dataloader)
        assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
        assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
        assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size
        assert replaced_loader.drop_last == self.dataloader.drop_last

        # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader)

    def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
        """
        Behaviour when ``dist`` is not a string but a ReproducibleBatchSampler:
        a new dataloader must be returned with its batch_sampler replaced by
        ``dist``.
        """
        dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
        replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False)

        assert not (replaced_loader is self.dataloader)
        assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
        assert replaced_loader.batch_sampler is dist

        # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader)

    def test_set_dist_repro_dataloader_with_dist_sampler(self):
        """
        Behaviour when ``dist`` is a reproducible sampler: a new dataloader must
        be returned with ``batch_sampler.sampler`` replaced by ``dist``.
        """
        dist = RandomSampler(self.dataset, shuffle=True)
        replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False)

        assert not (replaced_loader is self.dataloader)
        assert isinstance(replaced_loader.batch_sampler, BatchSampler)
        assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
        assert not (replaced_loader.batch_sampler is self.dataloader.batch_sampler)
        assert replaced_loader.batch_sampler.sampler is dist
        assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size

        # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader)

    def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self):
        """
        Behaviour when the dataloader already supports checkpoint-resume via a
        reproducible batch sampler: a new dataloader must be returned with all
        other settings identical to the original.
        """
        dataloader = DataLoader(
            dataset=self.dataset,
            batch_sampler=RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
        )
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

        assert not (replaced_loader is dataloader)
        assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
        assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
        assert replaced_loader.drop_last == dataloader.drop_last

        # self.check_set_dist_repro_dataloader(dataloader, replaced_loader)

    def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self):
        """
        Behaviour when the dataloader already supports checkpoint-resume via a
        reproducible sampler: a new dataloader must be returned with all other
        settings identical to the original.
        """
        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
        batch_sampler.sampler = RandomSampler(self.dataset, True)
        dataloader = DataLoader(
            self.dataset,
            batch_sampler=batch_sampler
        )
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

        assert not (replaced_loader is dataloader)
        assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
        assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
        assert replaced_loader.batch_sampler.batch_size == 2

        # self.check_set_dist_repro_dataloader(dataloader, replaced_loader)

    def check_set_dist_repro_dataloader(self, dataloader, replaced_loader):
        """
        Check on a single device that ``set_dist_repro_dataloader`` produced a
        loader that can resume mid-epoch: consume two batches, snapshot the
        sampler state, reload it, and verify the remaining indices complete the
        dataset.
        """
        # Iterate two batches.  The BatchSampler may have yielded more times
        # than the dataloader actually handed out.
        # Fix: the original unpacked ``idx, batch`` straight from the loader;
        # the companion loop below uses enumerate, so this one must as well.
        already_seen_idx = set()
        for idx, batch in enumerate(replaced_loader):
            already_seen_idx.update(batch)
            if idx >= 1:
                break
        if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
            sampler_states = replaced_loader.batch_sampler.state_dict()
        else:
            sampler_states = replaced_loader.batch_sampler.sampler.state_dict()
        print(sampler_states["data_idx"])

        # After reloading the state, the remaining content should be produced;
        # for PaddleNormalDataset the union (sorted) must be a full range.
        left_idxes = set()
        if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
            replaced_loader.batch_sampler.load_state_dict(sampler_states)
        else:
            replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states)
        for idx, batch in enumerate(replaced_loader):
            left_idxes.update(batch)

        assert len(left_idxes) + len(already_seen_idx) == len(self.dataset)
        assert len(left_idxes | already_seen_idx) == len(self.dataset)
class TestPaddleDriverFunctions:
    """
    Exercise the Driver base-class helper functions through a
    ``PaddleSingleDriver``.
    """
    @classmethod
    def setup_class(cls):
        # ``setup_class`` receives the class itself, so name the argument
        # ``cls`` (the original misleadingly called it ``self``).
        model = PaddleNormalModel_Classification_1(10, 32)
        cls.driver = PaddleSingleDriver(model, device="cpu")

    def test_check_single_optimizer_legality(self):
        """
        Behaviour when a single optimizer is passed.
        """
        optimizer = paddle.optimizer.Adam(
            parameters=self.driver.model.parameters(),
            learning_rate=0.01
        )
        self.driver.set_optimizers(optimizer)

        optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
        # Passing a torch optimizer must raise ValueError.
        with pytest.raises(ValueError):
            self.driver.set_optimizers(optimizer)

    def test_check_optimizers_legality(self):
        """
        Behaviour when a list of optimizers is passed.
        """
        optimizers = [
            paddle.optimizer.Adam(
                parameters=self.driver.model.parameters(),
                learning_rate=0.01
            ) for i in range(10)
        ]
        self.driver.set_optimizers(optimizers)

        optimizers += [
            torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
        ]
        with pytest.raises(ValueError):
            self.driver.set_optimizers(optimizers)

    def test_check_dataloader_legality_in_train(self):
        """
        Behaviour of ``_check_dataloader_legality`` when ``is_train`` is True.
        """
        dataloader = paddle.io.DataLoader(PaddleNormalDataset())
        PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True)

        # Case where both batch_size and batch_sampler are None.
        dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None)
        with pytest.raises(ValueError):
            PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True)

        # A torch dataloader must be rejected.
        dataloader = torch.utils.data.DataLoader(
            TorchNormalDataset(),
            batch_size=32, shuffle=True
        )
        with pytest.raises(ValueError):
            PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True)

    def test_check_dataloader_legality_in_test(self):
        """
        Behaviour of ``_check_dataloader_legality`` when ``is_train`` is False.
        """
        # In this case a dict of dataloaders must be passed.
        dataloader = {
            "train": paddle.io.DataLoader(PaddleNormalDataset()),
            "test": paddle.io.DataLoader(PaddleNormalDataset())
        }
        PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False)

        # Case where both batch_size and batch_sampler are None.
        dataloader = {
            "train": paddle.io.DataLoader(PaddleNormalDataset()),
            "test": paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None)
        }
        with pytest.raises(ValueError):
            PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False)

        # Passing something that is not a dict must raise.
        dataloader = paddle.io.DataLoader(PaddleNormalDataset())
        with pytest.raises(ValueError):
            PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False)

        # torch dataloaders must be rejected as well.
        train_loader = torch.utils.data.DataLoader(
            TorchNormalDataset(),
            batch_size=32, shuffle=True
        )
        test_loader = torch.utils.data.DataLoader(
            TorchNormalDataset(),
            batch_size=32, shuffle=True
        )
        dataloader = {"train": train_loader, "test": test_loader}
        with pytest.raises(ValueError):
            PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False)

    def test_tensor_to_numeric(self):
        """
        Behaviour of ``tensor_to_numeric`` for scalars and nested containers.
        """
        # single tensor
        tensor = paddle.to_tensor(3)
        res = PaddleSingleDriver.tensor_to_numeric(tensor)
        assert res == 3

        tensor = paddle.rand((3, 4))
        res = PaddleSingleDriver.tensor_to_numeric(tensor)
        assert res == tensor.tolist()

        # list of tensors
        tensor_list = [paddle.rand((6, 4, 2)) for i in range(10)]
        res = PaddleSingleDriver.tensor_to_numeric(tensor_list)
        assert isinstance(res, list)
        tensor_list = [t.tolist() for t in tensor_list]
        assert res == tensor_list

        # tuple of tensors
        tensor_tuple = tuple([paddle.rand((6, 4, 2)) for i in range(10)])
        res = PaddleSingleDriver.tensor_to_numeric(tensor_tuple)
        assert isinstance(res, tuple)
        tensor_tuple = tuple([t.tolist() for t in tensor_tuple])
        assert res == tensor_tuple

        # nested dict containing tensors
        tensor_dict = {
            "tensor": paddle.rand((3, 4)),
            "list": [paddle.rand((6, 4, 2)) for i in range(10)],
            "dict": {
                "list": [paddle.rand((6, 4, 2)) for i in range(10)],
                "tensor": paddle.rand((3, 4))
            },
            "int": 2,
            "string": "test string"
        }

        res = PaddleSingleDriver.tensor_to_numeric(tensor_dict)
        assert isinstance(res, dict)
        assert res["tensor"] == tensor_dict["tensor"].tolist()
        assert isinstance(res["list"], list)
        for r, d in zip(res["list"], tensor_dict["list"]):
            assert r == d.tolist()
        assert isinstance(res["int"], int)
        assert isinstance(res["string"], str)
        assert isinstance(res["dict"], dict)
        assert isinstance(res["dict"]["list"], list)
        for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]):
            assert r == d.tolist()
        assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist()

    def test_set_model_mode(self):
        """
        Behaviour of ``set_model_mode``.
        """
        self.driver.set_model_mode("train")
        assert self.driver.model.training
        self.driver.set_model_mode("eval")
        assert not self.driver.model.training
        # Any other mode must fail.
        with pytest.raises(AssertionError):
            self.driver.set_model_mode("test")

    def test_move_model_to_device_cpu(self):
        """
        ``move_model_to_device``: CPU placement.
        """
        PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu")
        assert self.driver.model.linear1.weight.place.is_cpu_place()

    def test_move_model_to_device_gpu(self):
        """
        ``move_model_to_device``: GPU placement.
        """
        PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu")
        assert self.driver.model.linear1.weight.place.is_gpu_place()
        assert self.driver.model.linear1.weight.place.gpu_device_id() == 0

    def test_worker_init_function(self):
        """
        ``worker_init_function``: for now just ensure it runs without error.
        """
        # TODO: verify correctness.
        PaddleSingleDriver.worker_init_function(0)

    def test_set_deterministic_dataloader(self):
        """
        ``set_deterministic_dataloader``: for now just ensure it runs.
        """
        # TODO: verify correctness.
        dataloader = DataLoader(PaddleNormalDataset())
        self.driver.set_deterministic_dataloader(dataloader)

    def test_set_sampler_epoch(self):
        """
        ``set_sampler_epoch``: for now just ensure it runs.
        """
        # TODO: verify correctness.
        dataloader = DataLoader(PaddleNormalDataset())
        self.driver.set_sampler_epoch(dataloader, 0)

    @pytest.mark.parametrize("batch_size", [16])
    @pytest.mark.parametrize("shuffle", [True, False])
    @pytest.mark.parametrize("drop_last", [True, False])
    def test_get_dataloader_args(self, batch_size, shuffle, drop_last):
        """
        ``get_dataloader_args`` on a plain dataloader.
        """
        dataloader = DataLoader(
            PaddleNormalDataset(),
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
        )
        res = PaddleSingleDriver.get_dataloader_args(dataloader)

        assert isinstance(res.dataset, PaddleNormalDataset)
        assert isinstance(res.batch_sampler, BatchSampler)
        if shuffle:
            assert isinstance(res.sampler, paddle.io.RandomSampler)
        else:
            assert isinstance(res.sampler, paddle.io.SequenceSampler)
        assert res.shuffle == shuffle
        assert res.batch_size == batch_size
        assert res.drop_last == drop_last

    @pytest.mark.parametrize("batch_size", [16])
    @pytest.mark.parametrize("shuffle", [True, False])
    @pytest.mark.parametrize("drop_last", [True, False])
    def test_get_dataloader_args_with_randombatchsampler(self, batch_size, shuffle, drop_last):
        """
        ``get_dataloader_args`` after the batch_sampler has been replaced.
        """
        dataset = PaddleNormalDataset()
        dataloader = DataLoader(
            dataset,
            batch_sampler=RandomBatchSampler(
                BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle),
                batch_size,
                drop_last,
            )
        )
        res = PaddleSingleDriver.get_dataloader_args(dataloader)

        assert isinstance(res.dataset, PaddleNormalDataset)
        assert isinstance(res.batch_sampler, RandomBatchSampler)
        if shuffle:
            assert isinstance(res.sampler, paddle.io.RandomSampler)
        else:
            assert isinstance(res.sampler, paddle.io.SequenceSampler)
        assert res.shuffle == shuffle
        assert res.batch_size == batch_size
        assert res.drop_last == drop_last

    @pytest.mark.parametrize("batch_size", [16])
    @pytest.mark.parametrize("shuffle", [True, False])
    @pytest.mark.parametrize("drop_last", [True, False])
    def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last):
        """
        ``get_dataloader_args`` after the sampler has been replaced.
        """
        dataset = PaddleNormalDataset()
        batch_sampler = BatchSampler(dataset, batch_size=batch_size, drop_last=drop_last)
        batch_sampler.sampler = RandomSampler(dataset, shuffle)
        dataloader = DataLoader(
            dataset,
            batch_sampler=batch_sampler,
        )
        res = PaddleSingleDriver.get_dataloader_args(dataloader)

        assert isinstance(res.dataset, PaddleNormalDataset)
        assert isinstance(res.batch_sampler, BatchSampler)
        assert isinstance(res.sampler, RandomSampler)
        assert res.shuffle == shuffle
        assert res.batch_size == batch_size
        assert res.drop_last == drop_last
@@ -1,4 +1,56 @@ | |||||
import unittest | |||||
import os | |||||
import pytest | |||||
os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
from fastNLP.core.drivers.paddle_driver.utils import ( | |||||
get_device_from_visible, | |||||
replace_batch_sampler, | |||||
replace_sampler, | |||||
) | |||||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler | |||||
import paddle | import paddle | ||||
from paddle.io import Dataset, DataLoader, DistributedBatchSampler | |||||
from paddle.io import DataLoader, BatchSampler | |||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset | |||||
@pytest.mark.parametrize(
    ("user_visible_devices, cuda_visible_devices, device, output_type, correct"),
    (
        ("0,1,2,3,4,5,6,7", "0", "cpu", str, "cpu"),
        ("0,1,2,3,4,5,6,7", "0", "cpu", int, "cpu"),
        ("0,1,2,3,4,5,6,7", "3,4,5", "gpu:4", int, 1),
        ("0,1,2,3,4,5,6,7", "3,4,5", "gpu:5", str, "gpu:2"),
        ("3,4,5,6", "3,5", 0, int, 0),
        ("3,6,7,8", "6,7,8", "gpu:2", str, "gpu:1"),
    )
)
def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices, device, output_type, correct):
    """
    ``get_device_from_visible`` must remap a logical device to its position
    inside the CUDA_VISIBLE_DEVICES / USER_CUDA_VISIBLE_DEVICES mapping,
    formatted as ``output_type`` requests.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
    os.environ["USER_CUDA_VISIBLE_DEVICES"] = user_visible_devices
    res = get_device_from_visible(device, output_type)
    assert res == correct
def test_replace_batch_sampler():
    """
    ``replace_batch_sampler`` must return a new dataloader over the same
    dataset whose batch_sampler is the one supplied.
    """
    dataset = PaddleNormalDataset(10)
    dataloader = DataLoader(dataset, batch_size=32)
    batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)

    replaced_loader = replace_batch_sampler(dataloader, batch_sampler)

    assert not (replaced_loader is dataloader)
    assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
    assert isinstance(replaced_loader.dataset, PaddleNormalDataset)
    assert len(replaced_loader.dataset) == len(dataset)
    assert replaced_loader.batch_sampler.batch_size == 16
def test_replace_sampler():
    """
    ``replace_sampler`` must return a new dataloader whose batch_sampler wraps
    the supplied reproducible sampler.
    """
    dataset = PaddleNormalDataset(10)
    dataloader = DataLoader(dataset, batch_size=32)
    sampler = RandomSampler(dataset)

    replaced_loader = replace_sampler(dataloader, sampler)

    assert not (replaced_loader is dataloader)
    assert isinstance(replaced_loader.batch_sampler, BatchSampler)
    assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
@@ -101,12 +101,18 @@ class RecordTrainerEventTriggerCallback(Callback): | |||||
def on_after_backward(self, trainer):
    """Trainer event hook: prints the event name when the backward pass ends."""
    print("on_after_backward")
def on_before_optimizers_step(self, trainer, optimizers):
    """
    Trainer event hook: prints the event name before the optimizers step.
    (The diff residue kept both the old ``on_before_optimizer_step`` and the
    renamed method; only the renamed one belongs in the file.)
    """
    print("on_before_optimizers_step")
def on_after_optimizers_step(self, trainer, optimizers):
    """Trainer event hook: prints the event name after the optimizers step."""
    print("on_after_optimizers_step")
def on_before_zero_grad(self, trainer, optimizers):
    """Trainer event hook: prints the event name before gradients are zeroed."""
    print("on_before_zero_grad")
def on_after_zero_grad(self, trainer, optimizers):
    """Trainer event hook: prints the event name after gradients are zeroed."""
    print("on_after_zero_grad")
def on_validate_begin(self, trainer):
    """Trainer event hook: prints the event name when validation is about to start."""
    print("on_validate_begin")