@@ -35,6 +35,7 @@ from fastNLP.envs import rank_zero_call
 from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
 from fastNLP.core.utils.exceptions import EarlyStopException
+from fastNLP.core.dataloaders import OverfitDataLoader
 
 
 class Trainer(TrainerEventTrigger):
@@ -356,6 +357,7 @@ class Trainer(TrainerEventTrigger):
             optimizers,
             device: Optional[Union[int, List[int], str]] = "cpu",
             n_epochs: int = 20,
+            overfit_batches: int = 0,
             evaluate_dataloaders=None,
             batch_step_fn: Optional[Callable] = None,
             evaluate_batch_step_fn: Optional[Callable] = None,
@@ -469,9 +471,6 @@ class Trainer(TrainerEventTrigger):
                 n_batches=n_batches
             )
 
-        if metrics is None and evaluate_dataloaders is not None:
-            raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.")
-
         if metrics is not None and evaluate_dataloaders is None:
             raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloaders'.")
@@ -495,33 +494,44 @@ class Trainer(TrainerEventTrigger):
         else:
             _dist_sampler = None
 
+        self.dataloader = self.train_dataloader
+        self.driver.set_deterministic_dataloader(self.dataloader)
+
+        self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
+                                                                reproducible=self.callback_manager._need_reproducible_sampler)
+
+        # set everything up for overfitting;
+        if overfit_batches != 0:
+            self.dataloader = OverfitDataLoader(self.dataloader, overfit_batches)
+        self.overfit_batches = overfit_batches
+
         self.evaluator = None
         self.monitor = monitor
         self.larger_better = larger_better
-        if metrics is not None and evaluate_dataloaders is not None:
-            check_evaluate_every(evaluate_every)
-            progress_bar = kwargs.get('progress_bar', 'auto')  # if not given, default to 'auto'
-            if not (isinstance(progress_bar, str) or progress_bar is None):  # should be a ProgressCallback; use its name.
-                progress_bar = progress_bar.name
-            self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
-                                       driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
-                                       evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
-                                       output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
-                                       use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
-                                       progress_bar=progress_bar,
-                                       check_dataloader_legality=kwargs.get('check_dataloader_legality', True))
+        if metrics is not None:
+            if overfit_batches != 0:
+                logger.warning("Notice you are trying to 'overfit' the model while also using 'metrics'; this may cause an error, "
+                               "because the 'metrics' were prepared for 'evaluate_dataloaders' but will now run on the train dataloader.")
+                evaluate_dataloaders = self.dataloader
+            if evaluate_dataloaders is not None:
+                check_evaluate_every(evaluate_every)
+                progress_bar = kwargs.get('progress_bar', 'auto')  # if not given, default to 'auto'
+                if not (isinstance(progress_bar, str) or progress_bar is None):  # should be a ProgressCallback; use its name.
+                    progress_bar = progress_bar.name
+                self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
+                                           driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
+                                           evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
+                                           output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
+                                           use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
+                                           progress_bar=progress_bar,
+                                           check_dataloader_legality=kwargs.get('check_dataloader_legality', True))
+            else:
+                raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloaders'.")
 
         if train_fn is not None and not isinstance(train_fn, str):
             raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")
         self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn)
         self.train_fn = train_fn
 
-        self.dataloader = self.train_dataloader
-        self.driver.set_deterministic_dataloader(self.dataloader)
-
-        self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
-                                                                reproducible=self.callback_manager._need_reproducible_sampler)
-
         self.evaluate_batch_step_fn = evaluate_batch_step_fn
         self.kwargs = kwargs
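Taken together, the Trainer-side wiring is: set up the train dataloader (with deterministic and distributed handling), wrap it in an OverfitDataLoader whenever overfit_batches is non-zero, and, if metrics are given, point the evaluator at those same cached training batches. Below is a minimal usage sketch of the new flag; `model`, `optimizer`, and `train_dl` are hypothetical placeholders, not part of this diff:

```python
# Sanity check: cache the first 4 train batches and fit on them repeatedly;
# a model that cannot drive the loss toward zero here likely has a bug.
# overfit_batches=-1 would cache and replay every batch instead.
trainer = Trainer(
    model=model,                # placeholder model
    driver="torch",
    device=0,
    optimizers=optimizer,       # placeholder optimizer
    train_dataloader=train_dl,  # placeholder dataloader
    overfit_batches=4,
    n_epochs=10,
)
trainer.run()
```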
@@ -7,10 +7,13 @@ __all__ = [
     'prepare_paddle_dataloader',
     'prepare_torch_dataloader',
-    "prepare_dataloader"
+    "prepare_dataloader",
+    "OverfitDataLoader"
 ]
 
 from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
 from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader, MixDataLoader
 from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader
-from .prepare_dataloader import prepare_dataloader
+from .prepare_dataloader import prepare_dataloader
+from .utils import OverfitDataLoader
@@ -1,4 +1,4 @@
-from typing import Callable, Any, Union
+from typing import Callable, Any, Union, Sequence
 from abc import ABC
 import inspect
 import ast
@@ -6,7 +6,8 @@ import ast
 from ..log import logger
 from ..utils.cache_results import get_func_calls, truncate_start_blanks
 
 __all__ = [
-    "indice_collate_wrapper"
+    "indice_collate_wrapper",
+    "OverfitDataLoader"
 ]
@@ -111,6 +112,42 @@ class HasLenGetitemType(ABC):
         return NotImplemented
 
 
+class OverfitDataLoader:
+    """
+    A simple iterator that mimics a real dataloader by caching part of the batches from the given
+    dataloader, so that the Trainer can implement its overfit functionality;
+    """
+
+    def __init__(self, dataloader, overfit_batches: int):
+        # keep the real dataloader on this object so that operations aimed at the real dataloader still work;
+        self.dataloader = dataloader
+        self.batches = []
+
+        if isinstance(overfit_batches, int):
+            if overfit_batches < 0 and overfit_batches != -1:
+                raise ValueError("When smaller than 0, parameter 'overfit_batches' can only be -1, which means "
+                                 "using all the data to check whether the model can be overfitted.")
+        else:
+            raise TypeError("Parameter 'overfit_batches' can only be of 'int' type; check the value you passed to 'Trainer'.")
+
+        if overfit_batches > len(dataloader):
+            logger.warning("Parameter 'overfit_batches' is bigger than the length of the train dataloader; "
+                           "all batches will be used.")
+
+        # eagerly cache the first 'overfit_batches' batches (all of them when overfit_batches == -1);
+        for idx, batch in enumerate(dataloader):
+            if idx < overfit_batches or overfit_batches == -1:
+                self.batches.append(batch)
+
+    def __len__(self):
+        return len(self.batches)
+
+    def __iter__(self):
+        for batch in self.batches:
+            yield batch
+
+    def __getattr__(self, item):
+        # delegate any other attribute access to the wrapped dataloader;
+        return getattr(self.dataloader, item)
+
+
 if __name__ == '__main__':
     def demo(*args, **kwargs):
         pass
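A standalone sketch of what the wrapper does, assuming a plain PyTorch DataLoader; the toy dataset is purely illustrative:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dl = DataLoader(TensorDataset(torch.arange(100.).unsqueeze(1)), batch_size=10)
overfit_dl = OverfitDataLoader(dl, overfit_batches=3)

assert len(overfit_dl) == 3         # only the first 3 of the 10 batches are cached
assert overfit_dl.batch_size == 10  # unknown attributes delegate to the real dataloader
for batch in overfit_dl:            # every iteration replays the same cached batches
    pass
```

Note that the batches are materialized eagerly in `__init__`, so `overfit_batches=-1` on a large dataset keeps an entire epoch of batches in memory.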
@@ -31,6 +31,7 @@ from fastNLP.envs import rank_zero_call
 from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
 from fastNLP.core.log import logger
 from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler
+from fastNLP.core.dataloaders import OverfitDataLoader
 
 
 class TorchDriver(Driver):
@@ -92,7 +93,7 @@ class TorchDriver(Driver):
             self.grad_scaler.update()
 
     def check_dataloader_legality(self, dataloader):
-        if not isinstance(dataloader, DataLoader):
+        if not isinstance(dataloader, (DataLoader, OverfitDataLoader)):
             raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
         if len(dataloader) == 0:
             logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it "
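The relaxed isinstance check lets the wrapper through, and the `len(dataloader) == 0` check that follows also behaves sensibly, because OverfitDataLoader defines `__len__` over its cached batches rather than delegating it. Continuing the toy dataloader from the sketch above:

```python
empty_dl = OverfitDataLoader(dl, overfit_batches=0)  # caches no batches at all
assert len(empty_dl) == 0  # would trip the driver's empty-dataloader warning
```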
@@ -286,6 +286,9 @@ def test_trainer_specific_params_1(
     assert trainer.driver.non_blocking is False
     assert trainer.driver.wo_auto_param_call is True
 
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
 
 @pytest.mark.torch
 @pytest.mark.parametrize("driver,device", [("torch", [0, 1])])  # ("torch", [0, 1]),("torch", 1)
@@ -332,5 +335,44 @@ def test_trainer_specific_params_2(
     assert _ddp_kwargs.get("broadcast_buffers") is True
     assert _ddp_kwargs.get("find_unused_parameters") is True
 
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
+
+@pytest.mark.torch
+@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
+@magic_argv_env_context
+def test_trainer_w_evaluator_overfit_torch(
+        model_and_optimizers: TrainerParameters,
+        overfit_batches,
+        num_train_batch_per_epoch
+):
+    """
+    Test that 'overfit_batches' is passed through and works correctly when an evaluator is used;
+    """
+    trainer = Trainer(
+        model=model_and_optimizers.model,
+        driver="torch",
+        device=0,
+        overfit_batches=overfit_batches,
+        optimizers=model_and_optimizers.optimizers,
+        train_dataloader=model_and_optimizers.train_dataloader,
+        evaluate_dataloaders={"dl": model_and_optimizers.evaluate_dataloaders},
+        input_mapping=model_and_optimizers.input_mapping,
+        output_mapping=model_and_optimizers.output_mapping,
+        metrics=model_and_optimizers.metrics,
+        n_epochs=2,
+        output_from_new_proc="all",
+        evaluate_every=-1,
+        torch_kwargs={
+            "non_blocking": False,
+            "set_grad_to_none": True
+        }
+    )
+
+    trainer.run(num_train_batch_per_epoch=num_train_batch_per_epoch)
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
@@ -361,5 +361,32 @@ def test_torch_wo_auto_param_call(
         dist.destroy_process_group()
 
 
+# test overfit_batches;
+@pytest.mark.torch
+@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
+@magic_argv_env_context
+def test_trainer_overfit_torch(
+        model_and_optimizers: TrainerParameters,
+        overfit_batches,
+        num_train_batch_per_epoch
+):
+    trainer = Trainer(
+        model=model_and_optimizers.model,
+        driver="torch",
+        device=0,
+        overfit_batches=overfit_batches,
+        optimizers=model_and_optimizers.optimizers,
+        train_dataloader=model_and_optimizers.train_dataloader,
+        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
+        input_mapping=model_and_optimizers.input_mapping,
+        output_mapping=model_and_optimizers.output_mapping,
+        metrics=model_and_optimizers.metrics,
+        output_from_new_proc="all",
+        n_epochs=2,
+    )
+
+    trainer.run(num_train_batch_per_epoch=num_train_batch_per_epoch)
+
+    if dist.is_initialized():
+        dist.destroy_process_group()