From 5b54a0cd731e2b49d87506e3dfd2c8a56e497f56 Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Sat, 9 Apr 2022 14:57:06 +0800
Subject: [PATCH 1/4] Change Trainer's catch_KeyboardInterrupt behavior to avoid the constant warning

---
 fastNLP/core/controllers/trainer.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py
index 73b712c9..a22f81d8 100644
--- a/fastNLP/core/controllers/trainer.py
+++ b/fastNLP/core/controllers/trainer.py
@@ -263,7 +263,7 @@ class Trainer(TrainerEventTrigger):
 
     def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1,
             num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True,
-            catch_KeyboardInterrupt=True):
+            catch_KeyboardInterrupt=None):
         """
         Note that for the very first run of resumable (checkpoint) training, i.e. before any resumable checkpoint file has
         been saved, resume_from should be left as None and a ModelCheckpoint should be used to save the checkpoint files;
@@ -273,15 +273,17 @@ class Trainer(TrainerEventTrigger):
         :param resume_from: the path from which to restore the trainer's state
         :param resume_training: whether to resume from the training state in the checkpoint. If False, only the states of the model and optimizers are restored.
         :param catch_KeyboardInterrupt: whether to catch KeyboardInterrupt; if caught, no exception is raised and the code after trainer.run() keeps
-            running.
+            running. By default a non-distributed driver catches it, while a distributed one does not (it cannot).
         :return:
         """
-
-        if self.driver.is_distributed():
-            if catch_KeyboardInterrupt:
-                logger.warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using multi-device "
-                               "driver. And we are going to set it to False.")
-            catch_KeyboardInterrupt = False
+        if catch_KeyboardInterrupt is None:
+            catch_KeyboardInterrupt = not self.driver.is_distributed()
+        else:
+            if self.driver.is_distributed():
+                if catch_KeyboardInterrupt:
+                    logger.warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using multi-device "
+                                   "driver. And we are going to set it to False.")
+                catch_KeyboardInterrupt = False
 
         self._set_num_eval_batch_per_dl(num_eval_batch_per_dl)
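The tri-state default this patch introduces can be summarized with a minimal sketch (not part of the patch; the helper name `resolve_catch_keyboard_interrupt` is hypothetical):

    # None  -> catch only when the driver is not distributed
    # True  -> forced on, but downgraded to False (with a warning) on distributed drivers
    # False -> never catch
    def resolve_catch_keyboard_interrupt(flag, is_distributed: bool) -> bool:
        if flag is None:
            return not is_distributed
        return False if is_distributed else flag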
From 929abc395307b4ed835388f52a9810a8f0cd5dd8 Mon Sep 17 00:00:00 2001
From: YWMditto
Date: Sat, 9 Apr 2022 15:28:13 +0800
Subject: [PATCH 2/4] Add test_logger.py

---
 fastNLP/envs/set_env_on_import.py             |   2 +-
 .../_test_distributed_launch_torch_1.py       |   4 +-
 .../_test_distributed_launch_torch_2.py       |   2 +-
 .../test_trainer_wo_evaluator_torch.py        |  14 +-
 tests/core/log/test_logger.py                 | 300 ++++++++++++++++++
 tests/core/samplers/test_sampler.py           |   7 -
 6 files changed, 310 insertions(+), 19 deletions(-)

diff --git a/fastNLP/envs/set_env_on_import.py b/fastNLP/envs/set_env_on_import.py
index db978bae..773c1e22 100644
--- a/fastNLP/envs/set_env_on_import.py
+++ b/fastNLP/envs/set_env_on_import.py
@@ -15,7 +15,7 @@ def remove_local_rank_in_argv():
     """
     index = -1
     for i, v in enumerate(sys.argv):
-        if v.startswith('--rank='):
+        if v.startswith('--local_rank='):
             os.environ['LOCAL_RANK'] = v.split('=')[1]
             index = i
             break
diff --git a/tests/core/controllers/_test_distributed_launch_torch_1.py b/tests/core/controllers/_test_distributed_launch_torch_1.py
index fb37c8d5..56261922 100644
--- a/tests/core/controllers/_test_distributed_launch_torch_1.py
+++ b/tests/core/controllers/_test_distributed_launch_torch_1.py
@@ -6,7 +6,7 @@ python -m torch.distributed.launch --nproc_per_node 2 tests/core/controllers/_te
 import argparse
 import os
 
-os.environ["CUDA_VISIBLE_DEVICES"] = "4,5"
+os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
 
 import sys
 path = os.path.abspath(__file__)
@@ -101,7 +101,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps(
     )
 
     trainer.run()
-    dist.barrier()
+    # dist.barrier()
 
 
 if __name__ == "__main__":
diff --git a/tests/core/controllers/_test_distributed_launch_torch_2.py b/tests/core/controllers/_test_distributed_launch_torch_2.py
index ad42672a..13d88248 100644
--- a/tests/core/controllers/_test_distributed_launch_torch_2.py
+++ b/tests/core/controllers/_test_distributed_launch_torch_2.py
@@ -6,7 +6,7 @@ python -m torch.distributed.launch --nproc_per_node 2 tests/core/controllers/_te
 import argparse
 import os
 
-os.environ["CUDA_VISIBLE_DEVICES"] = "4,5"
+os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
 
 import sys
 path = os.path.abspath(__file__)
diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py
index f8058fc9..0a280a0c 100644
--- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py
+++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py
@@ -77,15 +77,14 @@ def model_and_optimizers(request):
 
 # test on cpu;
 @pytest.mark.parametrize("driver,device", [("torch", "cpu")])
-@pytest.mark.parametrize("callbacks", [[RecordLossCallback(loss_threshold=0.1)]])
 @magic_argv_env_context
 def test_trainer_torch_without_evaluator(
         model_and_optimizers: TrainerParameters,
         driver,
         device,
-        callbacks,
         n_epochs=10,
 ):
+    callbacks = [RecordLossCallback(loss_threshold=0.1)]
     trainer = Trainer(
         model=model_and_optimizers.model,
         driver=driver,
@@ -108,8 +107,7 @@ def test_trainer_torch_without_evaluator(
         dist.destroy_process_group()
 
 
-@pytest.mark.parametrize("driver,device", [("torch", 4), ("torch", [4, 5])])  # ("torch", 4),
-@pytest.mark.parametrize("callbacks", [[RecordLossCallback(loss_threshold=0.1)]])
+@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [1, 2])])  # ("torch", 4),
 @pytest.mark.parametrize("fp16", [False, True])
 @pytest.mark.parametrize("accumulation_steps", [1, 3])
 @magic_argv_env_context
@@ -117,11 +115,11 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps(
         model_and_optimizers: TrainerParameters,
         driver,
         device,
-        callbacks,
         fp16,
         accumulation_steps,
         n_epochs=10,
 ):
+    callbacks = [RecordLossCallback(loss_threshold=0.1)]
     trainer = Trainer(
         model=model_and_optimizers.model,
         driver=driver,
@@ -148,7 +146,7 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps(
 
 
 # test accumulation_steps;
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 4), ("torch", [4, 5])])
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [1, 2])])
 @pytest.mark.parametrize("accumulation_steps", [1, 3])
 @magic_argv_env_context
 def test_trainer_torch_without_evaluator_accumulation_steps(
@@ -181,7 +179,7 @@ def test_trainer_torch_without_evaluator_accumulation_steps(
         dist.destroy_process_group()
 
 
-@pytest.mark.parametrize("driver,device", [("torch", [6, 7])])
+@pytest.mark.parametrize("driver,device", [("torch", [1, 2])])
 @pytest.mark.parametrize("output_from_new_proc", ["all", "ignore", "only_error", "test_log"])
 @magic_argv_env_context
 def test_trainer_output_from_new_proc(
@@ -244,7 +242,7 @@ def test_trainer_output_from_new_proc(
             synchronize_safe_rm(path)
 
 
-@pytest.mark.parametrize("driver,device", [("torch", [4, 5])])
+@pytest.mark.parametrize("driver,device", [("torch", [1, 2])])
 @pytest.mark.parametrize("cur_rank", [0])  # test in turn whether the other processes are correctly killed when the current process fails; , 1, 2, 3
 @magic_argv_env_context
 def test_trainer_on_exception(
("torch", [1, 2])]) # ("torch", 4), @pytest.mark.parametrize("fp16", [False, True]) @pytest.mark.parametrize("accumulation_steps", [1, 3]) @magic_argv_env_context @@ -117,11 +115,11 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps( model_and_optimizers: TrainerParameters, driver, device, - callbacks, fp16, accumulation_steps, n_epochs=10, ): + callbacks = [RecordLossCallback(loss_threshold=0.1)] trainer = Trainer( model=model_and_optimizers.model, driver=driver, @@ -148,7 +146,7 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps( # 测试 accumulation_steps; -@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 4), ("torch", [4, 5])]) +@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [1, 2])]) @pytest.mark.parametrize("accumulation_steps", [1, 3]) @magic_argv_env_context def test_trainer_torch_without_evaluator_accumulation_steps( @@ -181,7 +179,7 @@ def test_trainer_torch_without_evaluator_accumulation_steps( dist.destroy_process_group() -@pytest.mark.parametrize("driver,device", [("torch", [6, 7])]) +@pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) @pytest.mark.parametrize("output_from_new_proc", ["all", "ignore", "only_error", "test_log"]) @magic_argv_env_context def test_trainer_output_from_new_proc( @@ -244,7 +242,7 @@ def test_trainer_output_from_new_proc( synchronize_safe_rm(path) -@pytest.mark.parametrize("driver,device", [("torch", [4, 5])]) +@pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) @pytest.mark.parametrize("cur_rank", [0]) # 依次测试如果是当前进程出现错误,是否能够正确地 kill 掉其他进程; , 1, 2, 3 @magic_argv_env_context def test_trainer_on_exception( diff --git a/tests/core/log/test_logger.py b/tests/core/log/test_logger.py index e69de29b..da9b7b6b 100644 --- a/tests/core/log/test_logger.py +++ b/tests/core/log/test_logger.py @@ -0,0 +1,300 @@ +import os +import tempfile +import datetime +from pathlib import Path +import logging +import re + +from fastNLP.envs.env import FASTNLP_LAUNCH_TIME +from tests.helpers.utils import magic_argv_env_context +from fastNLP.core import synchronize_safe_rm + + +# 测试 TorchDDPDriver; +@magic_argv_env_context +def test_add_file_ddp_1(): + """ + 测试 path 是一个文件的地址,但是这个文件所在的文件夹存在; + + 多卡时根据时间创造文件名字有一个很大的 bug,就是不同的进程启动之间是有时差的,因此会导致他们各自输出到单独的 log 文件中; + """ + import torch + import torch.distributed as dist + + from fastNLP.core.log.logger import logger + from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver + from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 + + model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) + + driver = TorchDDPDriver( + model=model, + parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")], + output_from_new_proc="all" + ) + driver.setup() + msg = 'some test log msg' + + path = Path.cwd() + filepath = path.joinpath('log.txt') + handler = logger.add_file(filepath, mode="w") + logger.info(msg) + logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n") + + for h in logger.handlers: + if isinstance(h, logging.FileHandler): + h.flush() + dist.barrier() + with open(filepath, 'r') as f: + line = ''.join([l for l in f]) + assert msg in line + assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line + + pattern = re.compile(msg) + assert len(pattern.findall(line)) == 1 + + synchronize_safe_rm(filepath) + dist.barrier() + dist.destroy_process_group() + logger.removeHandler(handler) + + +@magic_argv_env_context +def 
+@magic_argv_env_context
+def test_add_file_ddp_2():
+    """
+    Test the case where path points to a file whose parent folder does not exist;
+    """
+
+    import torch
+    import torch.distributed as dist
+
+    from fastNLP.core.log.logger import logger
+    from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
+    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
+
+    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
+
+    driver = TorchDDPDriver(
+        model=model,
+        parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")],
+        output_from_new_proc="all"
+    )
+    driver.setup()
+
+    msg = 'some test log msg'
+
+    origin_path = Path.cwd()
+    try:
+        path = origin_path.joinpath("not_existed")
+        filepath = path.joinpath('log.txt')
+        handler = logger.add_file(filepath)
+        logger.info(msg)
+        logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n")
+        for h in logger.handlers:
+            if isinstance(h, logging.FileHandler):
+                h.flush()
+        dist.barrier()
+        with open(filepath, 'r') as f:
+            line = ''.join([l for l in f])
+
+        assert msg in line
+        assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line
+        pattern = re.compile(msg)
+        assert len(pattern.findall(line)) == 1
+    finally:
+        synchronize_safe_rm(path)
+        logger.removeHandler(handler)
+
+    dist.barrier()
+    dist.destroy_process_group()
+
+
+@magic_argv_env_context
+def test_add_file_ddp_3():
+    """
+    path = None;
+
+    Creating the file name from the timestamp has a big bug in the multi-card case: different processes start at
+    slightly different times, so each of them ends up writing to its own separate log file;
+    """
+    import torch
+    import torch.distributed as dist
+
+    from fastNLP.core.log.logger import logger
+    from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
+    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
+
+    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
+
+    driver = TorchDDPDriver(
+        model=model,
+        parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")],
+        output_from_new_proc="all"
+    )
+    driver.setup()
+    msg = 'some test log msg'
+
+    handler = logger.add_file()
+    logger.info(msg)
+    logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n")
+
+    for h in logger.handlers:
+        if isinstance(h, logging.FileHandler):
+            h.flush()
+    dist.barrier()
+    file = Path.cwd().joinpath(os.environ.get(FASTNLP_LAUNCH_TIME)+".log")
+    with open(file, 'r') as f:
+        line = ''.join([l for l in f])
+
+    # print(f"\nrank: {driver.get_local_rank()} line, {line}\n")
+    assert msg in line
+    assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line
+
+    pattern = re.compile(msg)
+    assert len(pattern.findall(line)) == 1
+
+    synchronize_safe_rm(file)
+    dist.barrier()
+    dist.destroy_process_group()
+    logger.removeHandler(handler)
+
+@magic_argv_env_context
+def test_add_file_ddp_4():
+    """
+    Test the case where path is a folder;
+    """
+
+    import torch
+    import torch.distributed as dist
+
+    from fastNLP.core.log.logger import logger
+    from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
+    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
+
+    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
+
+    driver = TorchDDPDriver(
+        model=model,
+        parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")],
+        output_from_new_proc="all"
+    )
+    driver.setup()
+    msg = 'some test log msg'
+
+    path = Path.cwd().joinpath("not_existed")
+    try:
+        handler = logger.add_file(path)
+        logger.info(msg)
+        logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n")
+
+        for h in logger.handlers:
+            if isinstance(h, logging.FileHandler):
+                h.flush()
+        dist.barrier()
+
+        file = path.joinpath(os.environ.get(FASTNLP_LAUNCH_TIME) + ".log")
+        with open(file, 'r') as f:
+            line = ''.join([l for l in f])
+        assert msg in line
+        assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line
+        pattern = re.compile(msg)
+        assert len(pattern.findall(line)) == 1
+    finally:
+        synchronize_safe_rm(path)
+        logger.removeHandler(handler)
+
+    dist.barrier()
+    dist.destroy_process_group()
+
+
+class TestLogger:
+    msg = 'some test log msg'
+
+    def test_add_file_1(self):
+        """
+        Test the case where path points to a file whose parent folder exists;
+        """
+        from fastNLP.core.log.logger import logger
+
+        path = Path(tempfile.mkdtemp())
+        try:
+            filepath = path.joinpath('log.txt')
+            handler = logger.add_file(filepath)
+            logger.info(self.msg)
+            with open(filepath, 'r') as f:
+                line = ''.join([l for l in f])
+            assert self.msg in line
+        finally:
+            synchronize_safe_rm(path)
+            logger.removeHandler(handler)
+
+    def test_add_file_2(self):
+        """
+        Test the case where path points to a file whose parent folder does not exist;
+        """
+        from fastNLP.core.log.logger import logger
+
+        origin_path = Path(tempfile.mkdtemp())
+
+        try:
+            path = origin_path.joinpath("not_existed")
+            path = path.joinpath('log.txt')
+            handler = logger.add_file(path)
+            logger.info(self.msg)
+            with open(path, 'r') as f:
+                line = ''.join([l for l in f])
+            assert self.msg in line
+        finally:
+            synchronize_safe_rm(origin_path)
+            logger.removeHandler(handler)
+
+    def test_add_file_3(self):
+        """
+        Test the case where path is None;
+        """
+        from fastNLP.core.log.logger import logger
+
+        handler = logger.add_file()
+        logger.info(self.msg)
+
+        path = Path.cwd()
+        cur_datetime = str(datetime.datetime.now().strftime('%Y-%m-%d'))
+        for file in path.iterdir():
+            if file.name.startswith(cur_datetime):
+                with open(file, 'r') as f:
+                    line = ''.join([l for l in f])
+                assert self.msg in line
+                file.unlink()
+        logger.removeHandler(handler)
+
+    def test_add_file_4(self):
+        """
+        Test the case where path is a folder;
+        """
+        from fastNLP.core.log.logger import logger
+
+        path = Path(tempfile.mkdtemp())
+        try:
+            handler = logger.add_file(path)
+            logger.info(self.msg)
+
+            cur_datetime = str(datetime.datetime.now().strftime('%Y-%m-%d'))
+            for file in path.iterdir():
+                if file.name.startswith(cur_datetime):
+                    with open(file, 'r') as f:
+                        line = ''.join([l for l in f])
+                    assert self.msg in line
+        finally:
+            synchronize_safe_rm(path)
+            logger.removeHandler(handler)
+
+    def test_stdout(self, capsys):
+        from fastNLP.core.log.logger import logger
+
+        handler = logger.set_stdout(stdout="raw")
+        logger.info(self.msg)
+        logger.debug('aabbc')
+        captured = capsys.readouterr()
+        assert "some test log msg\n" == captured.out
+
+        logger.removeHandler(handler)
+
diff --git a/tests/core/samplers/test_sampler.py b/tests/core/samplers/test_sampler.py
index 61e28dac..63d8e860 100644
--- a/tests/core/samplers/test_sampler.py
+++ b/tests/core/samplers/test_sampler.py
@@ -10,13 +10,6 @@ from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
 from tests.helpers.datasets.torch_data import TorchNormalDataset
 
 
-
-
-
-
-
-
 class SamplerTest(unittest.TestCase):
 
     def test_sequentialsampler(self):
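The docstrings in test_add_file_ddp_1/3 describe why naming log files by each process's own clock breaks down under DDP. A minimal runnable sketch of the idea (the helper name `default_log_name` is hypothetical; the real constant FASTNLP_LAUNCH_TIME holds the name of the environment variable, assumed here to be the literal string):

    import datetime
    import os

    def default_log_name() -> str:
        # If every rank derived the name from its own clock, the names could differ
        # across ranks and each rank would log to its own file. Reading a launch
        # timestamp fixed once before the worker processes start gives every rank
        # the same file name.
        launch_time = os.environ.get("FASTNLP_LAUNCH_TIME")
        if launch_time is None:
            launch_time = datetime.datetime.now().strftime("%Y-%m-%d-%H_%M_%S_%f")
        return launch_time + ".log"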
From 8e4abf2aa5b10d059673542bcd532faef5a5d023 Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Sun, 10 Apr 2022 00:08:19 +0800
Subject: [PATCH 3/4] Replace the driver's replace_sampler with set_dist_repro_dataloader; also change the driver.load/driver.save functions

---
 fastNLP/core/controllers/evaluator.py         |  6 +-
 fastNLP/core/controllers/trainer.py           | 72 ++++-------------
 fastNLP/core/drivers/driver.py                | 81 +++++++++++--------
 fastNLP/core/drivers/jittor_driver/mpi.py     |  3 +-
 .../drivers/jittor_driver/single_device.py    |  9 ++-
 fastNLP/core/drivers/paddle_driver/fleet.py   | 13 +--
 .../drivers/paddle_driver/single_device.py    | 11 +--
 fastNLP/core/drivers/torch_driver/ddp.py      | 15 ++--
 .../drivers/torch_driver/single_device.py     | 12 +--
 fastNLP/core/samplers/reproducible_sampler.py | 17 ++++
 fastNLP/envs/set_env_on_import.py             |  2 +-
 requirements.txt                              |  2 +-
 .../core/drivers/paddle_driver/test_fleet.py  | 13 +--
 .../paddle_driver/test_single_device.py       |  2 +-
 .../test_torch_replace_sampler.py             | 22 ++++-
 15 files changed, 148 insertions(+), 132 deletions(-)

diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py
index 44a76c4e..f58a7faf 100644
--- a/fastNLP/core/controllers/evaluator.py
+++ b/fastNLP/core/controllers/evaluator.py
@@ -124,11 +124,7 @@ class Evaluator:
             self.dataloaders = {}
             for name, dl in dataloaders.items():
                 # replace with the correct sampler
-                dl = self.driver.replace_sampler(
-                    dataloader=dl,
-                    dist_sampler=self._dist_sampler,
-                    reproducible=False
-                )
+                dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist=self._dist_sampler, reproducible=False)
                 self.dataloaders[name] = dl
 
         self.progress_bar = kwargs.get('progress_bar', 'auto')
Please report a bug to us.") - - if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): - states['sampler_states'] = sampler.state_dict() - else: - raise RuntimeError( - 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') if isinstance(folder, str): folder = Path(folder) @@ -601,9 +582,9 @@ class Trainer(TrainerEventTrigger): if not callable(model_save_fn): raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") rank_zero_call(model_save_fn)(folder) - self.driver.save(folder=folder, states=states, should_save_model=False, **kwargs) + self.driver.save(folder=folder, dataloader=self.dataloader, states=states, should_save_model=False, **kwargs) else: - self.driver.save(folder=folder, states=states, + self.driver.save(folder=folder, dataloader=self.dataloader, states=states, only_state_dict=only_state_dict, should_save_model=True, **kwargs) self.driver.barrier() @@ -616,9 +597,6 @@ class Trainer(TrainerEventTrigger): 保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleIterator; 注意我们目前不支持单卡到多卡的断点重训; - TODO:注意我们目前不支持 RandomSampler、BucketedSampler 或者 SortedSampler 之间的断点重训; - 因此如果用户自己需要使用 BucketedSampler,那么其需要自己在 Trainer 之前初始化 BucketedSampler,然后替换原始 Dataloader 中的 - sampler,不管其是第一次断点重训,还是之后的加载的重新训练; :param folder: 保存断点重训 states 的文件地址; :param resume_training: 是否从上次的 batch 开始训练,或者只从最近的 epoch 开始训练;注意如果 resume_training=True,那么我们 @@ -627,33 +605,23 @@ class Trainer(TrainerEventTrigger): self.driver.barrier() if isinstance(folder, str): folder = Path(folder) + + dataloader = self.dataloader + if not resume_training: + dataloader = None + if model_load_fn is not None: if not callable(model_load_fn): - raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") + raise ValueError("Parameter `model_save_fn` should be `Callable`.") rank_zero_call(model_load_fn)(folder) - states = self.driver.load(folder=folder, should_load_model=False, **kwargs) + states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs) else: - states = self.driver.load(folder=folder, only_state_dict=only_state_dict, should_load_model=True, **kwargs) + states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs) if not resume_training: return - # 1. 恢复 sampler 的状态; - dataloader_args = self.driver.get_dataloader_args(self.dataloader) - - sampler = dataloader_args.sampler - if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)): - # 说明这里需要使用 ReproduceSampler 来弄一下了 - if self.driver.is_distributed(): - raise RuntimeError("It is not allowed to use single device checkpoint retraining before but ddp now.") - sampler = ReproducibleBatchSampler( - batch_sampler=sampler, - batch_size=dataloader_args.batch_size, - drop_last=dataloader_args.drop_last - ) - sampler.load_state_dict(states['sampler_states']) - - self.driver.replace_sampler(self.dataloader, sampler) + self.dataloader = states.pop('dataloader') # 2. validate filter state; if self.evaluator is not None: @@ -668,22 +636,16 @@ class Trainer(TrainerEventTrigger): # 4. 
adjust trainer_state.batch_idx_in_epoch
         # the sampler here is a RandomSampler-like sampler, not a batch_sampler;
-        if not isinstance(sampler, ReproducibleBatchSampler):
-            if dataloader_args.drop_last:
-                self.trainer_state.batch_idx_in_epoch = len(sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
-            else:
-                self.trainer_state.batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
-                    (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
-        # the sampler is a batch_sampler;
-        else:
-            self.trainer_state.batch_idx_in_epoch = sampler.batch_idx_in_epoch
+        # The principle here is that 'number of batches still to be produced' + 'batch_idx_in_epoch' should equal 'total
+        # number of batches of an uninterrupted run'. Since 'number of batches still to be produced' is determined by how
+        # many samples are left, the equation can only be satisfied by adjusting 'batch_idx_in_epoch'
+        self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch')
 
         # 5. restore the states of all callbacks;
         self.on_load_checkpoint(states["callback_states"])
 
         self.driver.barrier()
 
-    """ These four functions make it convenient for users to customize their own batch_step_fn (which replaces the step function inside train_batch_loop) """
+    """ These four functions make it convenient for users to customize their own batch_step_fn (which replaces the batch_step_fn function inside train_batch_loop) """
 
     def train_step(self, batch):
         with self.driver.auto_cast():
diff --git a/fastNLP/core/drivers/driver.py b/fastNLP/core/drivers/driver.py
index fe263975..4b141761 100644
--- a/fastNLP/core/drivers/driver.py
+++ b/fastNLP/core/drivers/driver.py
@@ -2,7 +2,7 @@ import os
 import signal
 import sys
 from typing import Any, Sequence, List, Optional, Callable, Dict, Union
-from abc import ABC
+from abc import ABC, abstractmethod
 from datetime import datetime
 from pathlib import Path
 from io import BytesIO
@@ -14,7 +14,6 @@ __all__ = [
 
 from fastNLP.core.utils import nullcontext
 
-# todo: have Hang check which methods need @abstractmethod;
 class Driver(ABC):
     r"""
     The base class used to initialize a `Driver`; all custom `driver`s must inherit from this class;
@@ -32,29 +31,33 @@ class Driver(ABC):
         # self._consensus_file: Optional[Union[str, Path]] = None
         self._pids: Optional[List[int]] = None
 
+    @abstractmethod
     def setup(self):
         r"""
        This function initializes the training environment, e.g. moving the model to the corresponding device;
        this function is more complex for multi-card drivers, e.g. it may need to open the communication environment
        between processes and set environment variables and other required values;
        """
 
-    def replace_sampler(self, dataloader, dist_sampler: Optional[str], reproducible: bool = False):
+    def set_dist_repro_dataloader(self, dataloader, dist=None, reproducible: bool = False):
         r"""
-        Some special situations require replacing the dataloader's sampler, and this function in each driver provides that
-        capability; for example, in multi-card training we need to replace the sampler with a distributed sampler; and if
-        the user adds a resumable-training callback to the Trainer, we need to replace the sampler with a reproducible one;
-
-        :param dataloader: the original dataloader passed in from the trainer;
-        :param dist_sampler: should be a string, one of [None, "dist", "unrepeatdist"]; specifies which sampler to use;
-            currently this parameter serves distributed training: when the trainer's kwargs parameter `use_dist_sampler` is True,
-            this value is "dist", otherwise None; when the evaluator's kwargs parameter `use_dist_sampler` is True, this value is
-            "unrepeatdist", otherwise None;
-        :param reproducible: used in the `Trainer` to specify whether to switch to a resumable sampler (multi-card) or
-            batch_sampler (single-card); for a single-card Driver, True means we are currently resuming and our
-            `ReproducibleBatchSampler` replaces the dataloader's original batch_sampler; for a multi-card Driver, a
-            `RandomSampler` replaces the dataloader's original sampler;
-
-        :return: should return a new dataloader object with the sampler replaced (note that a new dataloader object must be returned here);
-        """
-        raise NotImplementedError("Each specific driver should implement its own `replace_sampler` function.")
+        Given the input dataloader, return a dataloader that supports distributed and reproducible iteration.
+
+        :param dataloader: the dataloader whose distributed and reproducible counterpart should be set up
+        :param dist: should be a string, one of [None, "dist", "unrepeatdist"]; None means the dataloader does not need to
+            be switched into a distributed state; 'dist' means the dataloader should guarantee that every gpu gets the same
+            number of batches, allowing a small number of samples to be duplicated across gpus; 'unrepeatdist' means that the
+            data iterated on all gpus, merged together, should be exactly the original data, allowing different gpus to have
+            different batch counts. When the trainer's kwargs parameter `use_dist_sampler` is True this value is "dist",
+            otherwise None; when the evaluator's kwargs parameter `use_dist_sampler` is True this value is "unrepeatdist",
+            otherwise None;
+        :param reproducible: if False, no special handling is needed; if True, the returned dataloader must be able to save
+            its current iteration state so that it can be loaded back later.
+        :return: should return a new dataloader object with the sampler replaced (note that a new dataloader object must be
+            returned here); additionally, if the incoming dataloader contains a ReproducibleIterator or a
+            ReproducibleBatchSampler, a re-initialized one must be placed into the returned dataloader. If dist is None and
+            reproducible is False, the original object may be returned directly.
+        """
+        if dist is None and reproducible is False:
+            return dataloader
+        raise NotImplementedError(f"Driver:{self.__class__.__name__} does not support `set_dist_repro_dataloader` "
+                                  f"function.")
 
     def set_deterministic_dataloader(self, dataloader):
         r"""
@@ -68,7 +71,7 @@ class Driver(ABC):
         :param cur_epoch_idx: the index of the current epoch;
         """
 
-
+    @abstractmethod
     def train_step(self, batch):
         """
        Implements the forward pass of training by calling the model's own `train_step` or `forward` method;
@@ -103,7 +106,7 @@ class Driver(ABC):
        therefore, if the user's evaluator mode is validate but the model passed in implements test_step instead of
        validate_step, we should alert the user to this behavior;
         """
-        raise NotImplementedError("Each specific driver should implement its own `predict_step` function.")
+        raise NotImplementedError("Each specific driver should implement its own `check_evaluator_mode` function.")
 
     @property
     def model(self):
@@ -234,6 +237,7 @@ class Driver(ABC):
         """
         self.optimizers = optimizers
 
+    @abstractmethod
     def backward(self, loss):
         """
        Implements the backward pass in deep learning;
 
        :param loss: the loss value used for backpropagation
@@ -242,12 +246,14 @@ class Driver(ABC):
         """
         raise NotImplementedError("Each specific driver should implement its own `backward` function.")
 
+    @abstractmethod
     def step(self):
         r"""
        Implements the parameter-update step in deep learning; parameters should be updated directly through the optimizers;
         """
         raise NotImplementedError("Each specific driver should implement its own `step` function.")
 
+    @abstractmethod
     def zero_grad(self, set_to_none: bool = False):
         r"""
        Implements zeroing the gradients in deep learning; gradients should be zeroed directly through the optimizers;
@@ -286,6 +292,7 @@ class Driver(ABC):
     def auto_cast(self, auto_cast):
         self._auto_cast = auto_cast
 
+    @abstractmethod
     def save_model(self, filepath: Union[str, Path, BytesIO], only_state_dict: bool = True, **kwargs):
         r"""
        The function that saves the model; note that the function `save` is the one used for resumable training;
@@ -296,6 +303,7 @@ class Driver(ABC):
         """
         raise NotImplementedError("Each specific driver should implement its own `save_model` function.")
 
+    @abstractmethod
     def load_model(self, filepath: Union[str, Path, BytesIO], only_state_dict: bool = False, **kwargs):
         r"""
        The function that loads the model; loads the model from filepath and assigns it to the current model.
@@ -307,7 +315,8 @@ class Driver(ABC):
         """
         raise NotImplementedError("Each specific driver should implement its own `load_model` function.")
 
-    def save(self, folder, states: Dict, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
+    @abstractmethod
+    def save(self, folder, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
         r"""
        The save function for resumable training; it is responsible for saving the state_dict of the model, the
        optimizers and fp16, and for saving the model itself (if should_save_model is True)
@@ -317,12 +326,14 @@ class Driver(ABC):
        :param states: a dict passed in by the trainer that already contains the states of the other objects that must be
            saved for resumable training. The Driver only needs to save this object and does not need to understand it; at
            driver.load() time, states must be handed back, and the value returned by load() must match what was passed in here.
+        :param dataloader: the dataloader currently in use; its state needs to be saved so that iteration can later resume
+            from the current position.
        :param only_state_dict: whether to save only the model's parameters; has no effect when should_save_model is False.
        :param should_save_model: whether the model should be saved; if False, the Driver is not responsible for saving the model.
         """
         raise NotImplementedError("Each specific driver should implement its own `save` function.")
 
-    def load(self, folder: Union[str, Path], only_state_dict: bool =True, should_load_model: bool = True, **kwargs) -> Dict:
+    @abstractmethod
+    def load(self, folder: Union[str, Path], dataloader, only_state_dict: bool =True, should_load_model: bool = True, **kwargs) -> Dict:
         r"""
        The load function for resumable training; note that it is responsible for reading the data and restoring the
        state_dict of the optimizers and fp16 and the model (depending on should_load_model), plus whatever else was saved by
        Driver.save(), and then returning a state dict to the trainer (whose content is the states that Driver.save() received).
        This function should be executed on all ranks.
@@ -331,11 +342,22 @@ class Driver(ABC):
        :param folder: reads the FASTNLP_CHECKPOINT_FILENAME file under this folder, plus FASTNLP_MODEL_FILENAME (if
            should_load_model is True).
+        :param dataloader: the currently given dataloader, which needs to be set up sensibly according to the saved
+            dataloader state. If this value is None, the 'dataloader' and 'batch_idx_in_epoch' entries do not need to be returned.
        :param only_state_dict: has no effect when should_save_model is False. If True, the saved content is the weights; if
            False, a whole model was saved, but the saved model's weights are still loaded through the current Driver's model
            rather than replacing the current model with the saved one.
        :param should_load_model: whether the model should be loaded; if False, the Driver is not responsible for loading it.
            If True but no corresponding model state is found in the saved state, an error is raised.
-        :return: must return the states content that was passed into the save function;
+        :return: must return the states content that was passed into the save function, plus
+            'dataloader': the dataloader set to a sensible state by combining the passed-in dataloader with the saved state;
+                the returned object may be the same one that was passed in. An error is raised if the saved number of data
+                samples differs from the current one.
+            'batch_idx_in_epoch': an int indicating which batch the current epoch has progressed to. Note that this value
+                cannot simply be read from the saved data, because the batch_size may have changed between the two runs. The
+                number should satisfy the equation
+                'number of batches the returned dataloader will still produce' + 'batch_idx_in_epoch' = 'total number of
+                batches of an uninterrupted run'.
+                Since 'the number of batches the returned dataloader will still produce' is fixed once batch_size and
+                drop_last are given, the equation can only be satisfied by adjusting batch_idx_in_epoch. A simple rule:
+                when drop_last is True, it equals floor(sample_in_this_rank/batch_size) - floor(num_left_samples/batch_size);
+                when drop_last is False, it equals ceil(sample_in_this_rank/batch_size) - ceil(num_left_samples/batch_size).
         """
         raise NotImplementedError("Each specific driver should implement its own `load` function.")
 
@@ -352,6 +374,7 @@ class Driver(ABC):
         """
         raise NotImplementedError("Each specific driver should implement its own `tensor_to_numeric` function.")
 
+    @abstractmethod
     def set_model_mode(self, mode: str):
         r"""
        Sets the model to `train` / `eval` mode, in order to switch between training and inference (which disables dropout etc.);
@@ -378,6 +401,7 @@ class Driver(ABC):
        ..., we need to move the model to cpu first and then back to gpu, so `unwrap_model` should not be called inside this
        function; instead, model is passed as a parameter;
         """
 
+    @abstractmethod
     def move_data_to_device(self, batch):
         r"""
        Moves the data to the specified device; batch may be a list or a dict, or a nested structure of these.
@@ -399,17 +423,6 @@ class Driver(ABC):
        Only used in distributed training scenarios.
         """
 
-    @staticmethod
-    def get_dataloader_args(dataloader):
-        """
-        Extracts the values of some attributes from the dataloader; the returned dataclass must contain the keys:
-            sampler, batch_sampler, batch_size, drop_last;
-
-        :param dataloader:
-        :return: a dataclass whose instance attributes should include each of the above, with matching names, so that the
-            trainer or other objects can use them conveniently;
-        """
-        raise NotImplementedError("Each specific driver should 
implemented its own `save` function.") - def load(self, folder: Union[str, Path], only_state_dict: bool =True, should_load_model: bool = True, **kwargs) -> Dict: + @abstractmethod + def load(self, folder: Union[str, Path], dataloader, only_state_dict: bool =True, should_load_model: bool = True, **kwargs) -> Dict: r""" 断点重训的加载函数,注意该函数会负责读取数据,并且恢复 optimizers , fp16 的 state_dict 和 模型(根据 should_load_model )和; 其它在 Driver.save() 函数中执行的保存操作,然后将一个 state 字典返回给 trainer ( 内容为Driver.save() 接受到的 states )。 @@ -331,11 +342,22 @@ class Driver(ABC): :param folder: 读取该 folder 下的 FASTNLP_CHECKPOINT_FILENAME 文件与 FASTNLP_MODEL_FILENAME (如果 should_load_model 为True)。 + :param dataloader: 当前给定 dataloader,需要根据 save 的 dataloader 状态合理设置。若该值为 None ,是不需要返回 'dataloader' + 以及 'batch_idx_in_epoch' 这两个值。 :param only_state_dict: 读取的,当 should_save_model 为 False ,该参数无效。如果为 True ,说明保存的内容为权重;如果为 False 说明保存的是模型,但也是通过当前 Driver 的模型去加载保存的模型的权重,而不是使用保存的模型替换当前模型。 :param should_load_model: 是否应该加载模型,如果为False,Driver 将不负责加载模型。若该参数为 True ,但在保存的状态中没有 找到对应的模型状态,则报错。 - :return: 需要返回 save 函数输入的 states 内容; + :return: 需要返回 save 函数输入的 states 内容 + 'dataloader',返回的是根据传入的 dataloader 与 保存的状态一起设置为合理的状态,可以返回的对象与传入的dataloader是同一个。 + 在保存与当前传入 data sample 数目不一致时报错。 + 'batch_idx_in_epoch': int 类型的数据,表明当前 epoch 进行到了进行到了第几个 batch 了。 请注意,该值不能是只能通过保存的 + 数据中读取的,因为前后两次运行 batch_size 可能由变化。该数字的原则应该符合以下等式 + '返回 dataloader 还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数' 。 + 由于 '返回 dataloader 还会产生的batch数量' 这个数量在 batch_size 与 drop_last 参数给定的情况下,无法改变,因此 + 只能通过调整 batch_idx_in_epoch 这个值来使等式成立。一个简单的计算原则如下 + 当drop_last为True,等同于 floor(sample_in_this_rank/batch_size) - floor(num_left_samples/batch_size); + 当drop_last为False,等同于 ceil(sample_in_this_rank/batch_size) - ceil(num_left_samples/batch_size)。 """ raise NotImplementedError("Each specific driver should implemented its own `load` function.") @@ -352,6 +374,7 @@ class Driver(ABC): """ raise NotImplementedError("Each specific driver should implemented its own `tensor_to_numeric` function.") + @abstractmethod def set_model_mode(self, mode: str): r""" 设置模型为 `train` / `eval` 的模式;目的是为切换模型训练和推理(会关闭dropout等)模式; @@ -378,6 +401,7 @@ class Driver(ABC): 中,我们需要先将模型移到 cpu 后,又再移到 gpu 上,因此不适宜在该函数内部调用 `unwrap_model`,而是将 model 作为该函数的参数; """ + @abstractmethod def move_data_to_device(self, batch): r""" 将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。 @@ -399,17 +423,6 @@ class Driver(ABC): 仅在多分布式训练场景中有使用。 """ - @staticmethod - def get_dataloader_args(dataloader): - """ - 用于从 dataloader 中抽取一些属性的值,返回的dataclass中必须包含以下的key: - sampler, batch_sampler, batch_size, drop_last; - - :param dataloader: - :return: 返回一个 dataclass,其实例属性应当包括以上的各个属性,并且其名字也应当与这些属性相同,从而方便 trainer 或者其它对象调用; - """ - raise NotImplementedError("Each specific driver should implemented its own `get_dataloader_args` function.") - def is_distributed(self) -> bool: """ 当前的 driver 实例是否是分布式的; diff --git a/fastNLP/core/drivers/jittor_driver/mpi.py b/fastNLP/core/drivers/jittor_driver/mpi.py index b02249f7..596148bc 100644 --- a/fastNLP/core/drivers/jittor_driver/mpi.py +++ b/fastNLP/core/drivers/jittor_driver/mpi.py @@ -70,7 +70,8 @@ class JittorMPIDriver(JittorDriver): def test_step(self, batch): return self._test_step(batch) - def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False): + def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]], + reproducible: bool = False, sampler_or_batch_sampler=None): pass def backward(self, loss): diff --git 
a/fastNLP/core/drivers/jittor_driver/single_device.py b/fastNLP/core/drivers/jittor_driver/single_device.py
index 452fa85c..f39053d3 100644
--- a/fastNLP/core/drivers/jittor_driver/single_device.py
+++ b/fastNLP/core/drivers/jittor_driver/single_device.py
@@ -99,14 +99,15 @@ class JittorSingleDriver(JittorDriver):
     def is_distributed(self):
         return False
 
-    def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False):
+    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
+                                  reproducible: bool = False, sampler_or_batch_sampler=None):
         # the reproducible functionality has not been implemented yet
-        if isinstance(dist_sampler, ReproducibleBatchSampler):
+        if isinstance(dist, ReproducibleBatchSampler):
             raise NotImplementedError
             dataloader.batch_sampler = dist_sample
-        if isinstance(dist_sampler, ReproducibleIterator):
+        if isinstance(dist, ReproducibleIterator):
             raise NotImplementedError
-            dataloader.batch_sampler.sampler = dist_sampler
+            dataloader.batch_sampler.sampler = dist
 
         if reproducible:
             raise NotImplementedError
diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py
index ff80cb9e..abd15bf3 100644
--- a/fastNLP/core/drivers/paddle_driver/fleet.py
+++ b/fastNLP/core/drivers/paddle_driver/fleet.py
@@ -316,13 +316,14 @@ class PaddleFleetDriver(PaddleDriver):
     def test_step(self, batch):
         return self._test_step(batch)
 
-    def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False):
+    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
+                                  reproducible: bool = False, sampler_or_batch_sampler=None):
         # iterableDataset is not supported for now
         assert dataloader.dataset_kind != _DatasetKind.ITER, \
             "FastNLP does not support `IteratorDataset` now."
-        if isinstance(dist_sampler, ReproducibleIterator):
-            dataloader.batch_sampler.sampler = dist_sampler
+        if isinstance(dist, ReproducibleIterator):
+            dataloader.batch_sampler.sampler = dist
             return dataloader
 
         # paddle's BatchSampler and DataLoader have no shuffle member; it can only be inferred from the sampler
@@ -334,14 +335,14 @@ class PaddleFleetDriver(PaddleDriver):
             shuffle = dataloader.batch_sampler.shuffle
 
         # trainer, evaluator
-        if dist_sampler is None:
+        if dist is None:
             if reproducible:
                 raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our "
                                    "control.")
             else:
                 return dataloader
         # trainer
-        elif dist_sampler == "dist":
+        elif dist == "dist":
             # if the user's trainer.use_dist_sampler is True, whether they are resuming from a checkpoint does not affect the behavior here;
             if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator):
                 dataloader.batch_sampler.sampler.set_distributed(
@@ -364,7 +365,7 @@ class PaddleFleetDriver(PaddleDriver):
                 dataloader.batch_sampler.sampler = sampler
                 return dataloader
         # evaluator
-        elif dist_sampler == "unrepeatdist":
+        elif dist == "unrepeatdist":
             sampler = UnrepeatedDistributedSampler(
                 dataset=dataloader.dataset,
                 shuffle=shuffle,
diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py
index 1dad6d97..0b4d09bb 100644
--- a/fastNLP/core/drivers/paddle_driver/single_device.py
+++ b/fastNLP/core/drivers/paddle_driver/single_device.py
@@ -133,15 +133,16 @@ class PaddleSingleDriver(PaddleDriver):
         """
         return paddle_move_data_to_device(batch, "gpu:0")
 
-    def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False):
+    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
+                                  reproducible: bool = False, sampler_or_batch_sampler=None):
         # IteratorDataset is not supported for now
         assert dataloader.dataset_kind != _DatasetKind.ITER, \
             "FastNLP does not support `IteratorDataset` now."
-        if isinstance(dist_sampler, ReproducibleBatchSampler):
-            dataloader.batch_sampler = dist_sampler
+        if isinstance(dist, ReproducibleBatchSampler):
+            dataloader.batch_sampler = dist
             return dataloader
-        if isinstance(dist_sampler, ReproducibleIterator):
-            dataloader.batch_sampler.sampler = dist_sampler
+        if isinstance(dist, ReproducibleIterator):
+            dataloader.batch_sampler.sampler = dist
             return dataloader
 
         if reproducible:
diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py
index 637b1e67..9b3325d8 100644
--- a/fastNLP/core/drivers/torch_driver/ddp.py
+++ b/fastNLP/core/drivers/torch_driver/ddp.py
@@ -445,21 +445,22 @@ class TorchDDPDriver(TorchDriver):
         # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST})
         return self._test_step(batch)
 
-    def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False):
-        if isinstance(dist_sampler, ReproducibleIterator):
+    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
+                                  reproducible: bool = False, sampler_or_batch_sampler=None):
+        if isinstance(dist, ReproducibleIterator):
             # note that dist_sampler.set_distributed does not need to be called here; if the user uses TorchDDPDriver, it has already been called during Trainer initialization;
-            dist_sampler = re_instantiate_sampler(dist_sampler)
-            return replace_sampler(dataloader, dist_sampler)
+            dist = re_instantiate_sampler(dist)
+            return replace_sampler(dataloader, dist)
 
         # trainer, evaluator
-        if dist_sampler is None:
+        if dist is None:
             if reproducible:
                 raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
                                    "control.")
             else:
                 return dataloader
         # trainer
-        elif dist_sampler == "dist":
+        elif dist == "dist":
             args = self.get_dataloader_args(dataloader)
             # if the user's trainer.use_dist_sampler is True, whether they are resuming from a checkpoint does not affect the behavior here;
             if isinstance(args.sampler, ReproducibleIterator):
@@ -485,7 +486,7 @@ class TorchDDPDriver(TorchDriver):
             return replace_sampler(dataloader, sampler)
         # evaluator
-        elif dist_sampler == "unrepeatdist":
+        elif dist == "unrepeatdist":
             args = self.get_dataloader_args(dataloader)
             sampler = UnrepeatedDistributedSampler(
                 dataset=args.dataset,
diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py
index b4ce0ecf..034292eb 100644
--- a/fastNLP/core/drivers/torch_driver/single_device.py
+++ b/fastNLP/core/drivers/torch_driver/single_device.py
@@ -130,12 +130,12 @@ class TorchSingleDriver(TorchDriver):
         else:
             return self._test_step(batch)
 
-    def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
-                        reproducible: bool = False):
-        if isinstance(dist_sampler, ReproducibleBatchSampler):
-            return replace_batch_sampler(dataloader, dist_sampler)
-        elif isinstance(dist_sampler, ReproducibleIterator):
-            return replace_sampler(dataloader, dist_sampler)
+    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
+                                  reproducible: bool = False, sampler_or_batch_sampler=None):
+        if isinstance(dist, ReproducibleBatchSampler):
+            return replace_batch_sampler(dataloader, dist)
+        elif isinstance(dist, ReproducibleIterator):
+            return replace_sampler(dataloader, dist)
 
         if reproducible:
             args = self.get_dataloader_args(dataloader)
diff --git 
a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py
index e0211790..1382282a 100644
--- a/fastNLP/core/samplers/reproducible_sampler.py
+++ b/fastNLP/core/samplers/reproducible_sampler.py
@@ -50,6 +50,14 @@ class ReproducibleIterator:
 
 class RandomSampler(ReproducibleIterator):
     def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
+        """
+
+        :param dataset: a data container implementing the __len__ method
+        :param shuffle: whether to shuffle the order on every iteration.
+        :param seed: the random seed.
+        :param kwargs: not meant for users; used internally by fastNLP
+        """
         self.dataset = dataset
         self.shuffle = shuffle
@@ -208,6 +216,15 @@ class RandomSampler(ReproducibleIterator):
 class ReproducibleBatchSampler:
     # the values of these two parameters should be fetched via the driver's get_dataloader_args function;
     def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs):
+        """
+        A wrapper that makes the state of a batch_sampler object recoverable.
+
+        :param batch_sampler: an iterable that yields numbers or lists of numbers. ReproducibleBatchSampler first iterates
+            over this object once, caches the yielded indices, and then hands them out as lists of batch_size indices when used.
+        :param batch_size: the size of each batch.
+        :param drop_last: whether to drop the last batch if it cannot be filled up to batch_size samples.
+        :param kwargs: used internally by fastNLP.
+        """
         self.batch_sampler = batch_sampler
         self.batch_size = batch_size
         self.drop_last = drop_last
diff --git a/fastNLP/envs/set_env_on_import.py b/fastNLP/envs/set_env_on_import.py
index db978bae..773c1e22 100644
--- a/fastNLP/envs/set_env_on_import.py
+++ b/fastNLP/envs/set_env_on_import.py
@@ -15,7 +15,7 @@ def remove_local_rank_in_argv():
     """
     index = -1
     for i, v in enumerate(sys.argv):
-        if v.startswith('--rank='):
+        if v.startswith('--local_rank='):
             os.environ['LOCAL_RANK'] = v.split('=')[1]
             index = i
             break
diff --git a/requirements.txt b/requirements.txt
index 2e2808d1..ce82c20d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,4 +3,4 @@ prettytable>=0.7.2
 requests
 regex!=2019.12.17
 rich==11.2.0
-# fsspec[http]>=2021.05.0, !=2021.06.0
\ No newline at end of file
+packaging
\ No newline at end of file
diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py
index aea4ca40..e20866b3 100644
--- a/tests/core/drivers/paddle_driver/test_fleet.py
+++ b/tests/core/drivers/paddle_driver/test_fleet.py
@@ -1,12 +1,9 @@
 import pytest
-import sys
 import os
 import numpy as np
 
-from fastNLP.envs.set_backend import set_env
 from fastNLP.envs.set_env_on_import import set_env_on_import_paddle
 
 set_env_on_import_paddle()
-set_env("paddle")
 import paddle
 import paddle.distributed as dist
 from paddle.io import DataLoader
@@ -54,6 +51,7 @@ def test_move_data_to_device():
 
     dist.barrier()
 
+
 @magic_argv_env_context
 def test_is_distributed():
     print(os.getenv("CUDA_VISIBLE_DEVICES"))
@@ -64,6 +62,7 @@ def test_is_distributed():
     driver = PaddleFleetDriver(
         model=paddle_model,
         parallel_device=[0,1],
+        output_from_new_proc='all'
     )
     driver.set_optimizers(paddle_opt)
     # distinguish between launch and the subprocess's setup
@@ -79,6 +78,7 @@ def test_is_distributed():
     synchronize_safe_rm("log")
     dist.barrier()
 
+
 @magic_argv_env_context
 def test_get_no_sync_context():
     """
@@ -105,6 +105,7 @@ def test_get_no_sync_context():
     synchronize_safe_rm("log")
     dist.barrier()
 
+
 @magic_argv_env_context
 def test_is_global_zero():
     try:
@@ -128,6 +129,8 @@ def test_is_global_zero():
         synchronize_safe_rm("log")
         dist.barrier()
 
+
+
 @magic_argv_env_context
 def test_unwrap_model():
     try:
@@ -204,7 +207,7 @@ def test_replace_sampler(dist_sampler, reproducible):
         else:
             driver.setup()
         dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True)
-        driver.replace_sampler(dataloader, dist_sampler, reproducible)
+        driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible)
     finally:
synchronize_safe_rm("log") dist.barrier() @@ -243,7 +246,7 @@ class SingleMachineMultiGPUTrainingTestCase: parallel_device=gpus, ) driver.set_optimizers(paddle_opt) - dataloader = driver.replace_sampler(dataloader) + dataloader = driver.set_dist_repro_dataloader(dataloader, ) driver.setup() # 检查model_device self.assertEqual(driver.model_device, f"gpu:{os.environ['PADDLE_LOCAL_DEVICE_IDS']}") diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 4c9ff5f8..2cb6d5be 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -164,4 +164,4 @@ class TestSingleDeviceFunction: """ dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True) - res = self.driver.replace_sampler(dataloader, dist_sampler, reproducible) \ No newline at end of file + res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible) \ No newline at end of file diff --git a/tests/core/drivers/torch_driver/test_torch_replace_sampler.py b/tests/core/drivers/torch_driver/test_torch_replace_sampler.py index edb98190..81d693fc 100644 --- a/tests/core/drivers/torch_driver/test_torch_replace_sampler.py +++ b/tests/core/drivers/torch_driver/test_torch_replace_sampler.py @@ -33,11 +33,15 @@ def check_replace_sampler(driver): # dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,ReproducibleBatchSampler # reproducible 是 True 和 False + # 需要 check 返回的 sampler 和 dataloader 都不同了 assert driver.is_distributed() is False, "This test only for non distributed sampler." ds = SequenceDataSet(10) dataloader = DataLoader(dataset=ds, batch_size=2, collate_fn=lambda x:x, shuffle=True) - dl1 = driver.replace_sampler(dataloader, dist_sampler='dist', reproducible=True) + dl1 = driver.set_dist_repro_dataloader(dataloader, dist='dist', reproducible=True) + + assert not (dl1.sampler is dataloader.sampler), "The sampler should not the same one." + assert not (dl1 is dataloader), "The dataloader should not the same one." 
From a376eea776eed523ff0a9118a2cc97d0a3134db8 Mon Sep 17 00:00:00 2001
From: YWMditto
Date: Sun, 10 Apr 2022 12:56:49 +0800
Subject: [PATCH 4/4] Rework the resumable-training logic; mainly change trainer.save/load and the driver.save and load functions

---
 .../callbacks/load_best_model_callback.py     |  2 +-
 .../core/drivers/torch_driver/torch_driver.py | 62 +++++++++++++++++--
 .../test_checkpoint_callback_torch.py         |  4 +-
 3 files changed, 61 insertions(+), 7 deletions(-)

diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py
index bd6b8e66..b4ef4e62 100644
--- a/fastNLP/core/callbacks/load_best_model_callback.py
+++ b/fastNLP/core/callbacks/load_best_model_callback.py
@@ -81,7 +81,7 @@ class LoadBestModelCallback(Callback):
                                                 real_monitor=self._real_monitor, res=results)
         if (monitor_value < self.monitor_value and self.larger_better is False) or \
-            (monitor_value > self.monitor_value and self.larger_better):
+                (monitor_value > self.monitor_value and self.larger_better):
             self.monitor_value = monitor_value
             if self.real_save_folder:
                 trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py
index 0e1a45e0..96d11761 100644
--- a/fastNLP/core/drivers/torch_driver/torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/torch_driver.py
@@ -30,6 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
 from fastNLP.envs import rank_zero_call
 from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
 from fastNLP.core.log import logger
+from fastNLP.core.samplers import ReproducibleBatchSampler
 
 
 class TorchDriver(Driver):
@@ -178,8 +179,28 @@ class TorchDriver(Driver):
             model.load_state_dict(res.state_dict())
 
     @rank_zero_call
-    def save(self, folder: Path, states: Dict, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
-        # 1. save the model's state;
+    def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
+        # the dataloader parameter passed in is the trainer's dataloader attribute, because we never mutate any of the
+        # driver's dataloaders; instead we change the dataloader's state through trainer.dataloader, adapting it to the
+        # training or evaluation environment;
+
+        # 1. 
the sampler's state, since we support resume training, i.e. restoring precisely to one specific batch;
+        # first, a pytorch DataLoader always has a sampler; besides, during checkpoint-resumed training we always replace the
+        # dataloader's sampler with a `ReproducibleIterator` in `replace_sampler`; otherwise, in the single-card case, the
+        # batch_sampler is replaced with a `ReproducibleBatchSampler`;
+        dataloader_args = self.get_dataloader_args(dataloader)
+        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
+            sampler = dataloader_args.batch_sampler
+        elif dataloader_args.sampler:
+            sampler = dataloader_args.sampler
+        else:
+            raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
+
+        if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
+            states['sampler_states'] = sampler.state_dict()
+        else:
+            raise RuntimeError(
+                'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')
+
+        # 2. save the model's state;
         if should_save_model:
             model = self.unwrap_model()
             if only_state_dict:
@@ -191,7 +212,7 @@ class TorchDriver(Driver):
                 torch.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME))
                 logger.debug("Save model")
 
-        # 2. save the optimizers' state;
+        # 3. save the optimizers' state;
         optimizers_state_dict = {}
         for i in range(len(self.optimizers)):
             optimizer: torch.optim.Optimizer = self.optimizers[i]
@@ -203,7 +224,7 @@ class TorchDriver(Driver):
         states["optimizers_state_dict"] = optimizers_state_dict
         torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
 
-    def load(self, folder: Path, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
+    def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
         states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))
 
         # 1. load the optimizers' state;
@@ -224,6 +245,39 @@ class TorchDriver(Driver):
             model.load_state_dict(res.state_dict())
             logger.debug("Load model.")
 
+        # 3. restore the sampler's state;
+        dataloader_args = self.get_dataloader_args(dataloader)
+
+        sampler = dataloader_args.sampler
+        if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
+            # this means a ReproduceSampler needs to be used here
+            if self.is_distributed():
+                raise RuntimeError(
+                    "It is not allowed to use single device checkpoint retraining before but ddp now.")
+            sampler = ReproducibleBatchSampler(
+                batch_sampler=sampler,
+                batch_size=dataloader_args.batch_size,
+                drop_last=dataloader_args.drop_last
+            )
+        sampler.load_state_dict(states['sampler_states'])
+
+        states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
+
+        # 4. 
adjust trainer_state.batch_idx_in_epoch
+        # the sampler is a RandomSampler-like sampler, not a batch_sampler;
+        if not isinstance(sampler, ReproducibleBatchSampler):
+            if dataloader_args.drop_last:
+                batch_idx_in_epoch = len(
+                    sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
+            else:
+                batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
+                    (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
+        # the sampler is a batch_sampler;
+        else:
+            batch_idx_in_epoch = sampler.batch_idx_in_epoch
+
+        states["batch_idx_in_epoch"] = batch_idx_in_epoch
+
         return states
 
     def get_evaluate_context(self):
diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py
index 759135f0..f7cc6e5f 100644
--- a/tests/core/callbacks/test_checkpoint_callback_torch.py
+++ b/tests/core/callbacks/test_checkpoint_callback_torch.py
@@ -316,7 +316,7 @@ def test_model_checkpoint_callback_2(
         dist.destroy_process_group()
 
 
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [6, 7]), ("torch", 7)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
 @pytest.mark.parametrize("version", [0, 1])
 @pytest.mark.parametrize("only_state_dict", [True, False])
 @magic_argv_env_context
@@ -466,7 +466,7 @@ def test_trainer_checkpoint_callback_1(
 
 
 # test saving and loading huggingface transformers models by writing custom model_save_fn and model_load_fn;
-@pytest.mark.parametrize("driver,device", [("torch_ddp", [0, 1]), ("torch", 1)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
 @pytest.mark.parametrize("version", [0, 1])
 @magic_argv_env_context
 def test_trainer_checkpoint_callback_2(
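The batch_idx_in_epoch arithmetic that PATCH 3/4 documents and PATCH 4/4 implements in TorchDriver.load can be checked with a minimal runnable sketch (the helper name `compute_batch_idx_in_epoch` is hypothetical):

    def compute_batch_idx_in_epoch(num_samples, num_left_samples, batch_size, drop_last):
        # invariant from the patches: 'batches still to come' + 'batch_idx_in_epoch'
        # == 'total batches of an uninterrupted epoch', so the already-consumed
        # batch count is the total minus the remainder.
        if drop_last:
            # floor(num_samples/batch_size) - floor(num_left_samples/batch_size)
            return num_samples // batch_size - num_left_samples // batch_size
        # ceil(num_samples/batch_size) - ceil(num_left_samples/batch_size)
        return (num_samples + batch_size - 1) // batch_size \
               - (num_left_samples + batch_size - 1) // batch_size

    # e.g. 10 samples, 7 left (3 consumed), batch_size=2:
    assert compute_batch_idx_in_epoch(10, 7, 2, drop_last=True) == 2
    assert compute_batch_idx_in_epoch(10, 7, 2, drop_last=False) == 1

With drop_last=False the odd leftover sample still forms a pending batch, so one fewer batch counts as already consumed; this is exactly why the value must be recomputed rather than read back verbatim when batch_size changes between runs.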