From 929abc395307b4ed835388f52a9810a8f0cd5dd8 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Sat, 9 Apr 2022 15:28:13 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E5=8A=A0=E5=85=A5=E4=BA=86=20test=5Flogger?= =?UTF-8?q?.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/envs/set_env_on_import.py | 2 +- .../_test_distributed_launch_torch_1.py | 4 +- .../_test_distributed_launch_torch_2.py | 2 +- .../test_trainer_wo_evaluator_torch.py | 14 +- tests/core/log/test_logger.py | 300 ++++++++++++++++++ tests/core/samplers/test_sampler.py | 7 - 6 files changed, 310 insertions(+), 19 deletions(-) diff --git a/fastNLP/envs/set_env_on_import.py b/fastNLP/envs/set_env_on_import.py index db978bae..773c1e22 100644 --- a/fastNLP/envs/set_env_on_import.py +++ b/fastNLP/envs/set_env_on_import.py @@ -15,7 +15,7 @@ def remove_local_rank_in_argv(): """ index = -1 for i, v in enumerate(sys.argv): - if v.startswith('--rank='): + if v.startswith('--local_rank='): os.environ['LOCAL_RANK'] = v.split('=')[1] index = i break diff --git a/tests/core/controllers/_test_distributed_launch_torch_1.py b/tests/core/controllers/_test_distributed_launch_torch_1.py index fb37c8d5..56261922 100644 --- a/tests/core/controllers/_test_distributed_launch_torch_1.py +++ b/tests/core/controllers/_test_distributed_launch_torch_1.py @@ -6,7 +6,7 @@ python -m torch.distributed.launch --nproc_per_node 2 tests/core/controllers/_te import argparse import os -os.environ["CUDA_VISIBLE_DEVICES"] = "4,5" +os.environ["CUDA_VISIBLE_DEVICES"] = "1,2" import sys path = os.path.abspath(__file__) @@ -101,7 +101,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps( ) trainer.run() - dist.barrier() + # dist.barrier() if __name__ == "__main__": diff --git a/tests/core/controllers/_test_distributed_launch_torch_2.py b/tests/core/controllers/_test_distributed_launch_torch_2.py index ad42672a..13d88248 100644 --- a/tests/core/controllers/_test_distributed_launch_torch_2.py +++ b/tests/core/controllers/_test_distributed_launch_torch_2.py @@ -6,7 +6,7 @@ python -m torch.distributed.launch --nproc_per_node 2 tests/core/controllers/_te import argparse import os -os.environ["CUDA_VISIBLE_DEVICES"] = "4,5" +os.environ["CUDA_VISIBLE_DEVICES"] = "1,2" import sys path = os.path.abspath(__file__) diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py index f8058fc9..0a280a0c 100644 --- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py @@ -77,15 +77,14 @@ def model_and_optimizers(request): # 测试一下 cpu; @pytest.mark.parametrize("driver,device", [("torch", "cpu")]) -@pytest.mark.parametrize("callbacks", [[RecordLossCallback(loss_threshold=0.1)]]) @magic_argv_env_context def test_trainer_torch_without_evaluator( model_and_optimizers: TrainerParameters, driver, device, - callbacks, n_epochs=10, ): + callbacks = [RecordLossCallback(loss_threshold=0.1)] trainer = Trainer( model=model_and_optimizers.model, driver=driver, @@ -108,8 +107,7 @@ def test_trainer_torch_without_evaluator( dist.destroy_process_group() -@pytest.mark.parametrize("driver,device", [("torch", 4), ("torch", [4, 5])]) # ("torch", 4), -@pytest.mark.parametrize("callbacks", [[RecordLossCallback(loss_threshold=0.1)]]) +@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [1, 2])]) # ("torch", 4), @pytest.mark.parametrize("fp16", [False, True]) 
@pytest.mark.parametrize("accumulation_steps", [1, 3]) @magic_argv_env_context @@ -117,11 +115,11 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps( model_and_optimizers: TrainerParameters, driver, device, - callbacks, fp16, accumulation_steps, n_epochs=10, ): + callbacks = [RecordLossCallback(loss_threshold=0.1)] trainer = Trainer( model=model_and_optimizers.model, driver=driver, @@ -148,7 +146,7 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps( # 测试 accumulation_steps; -@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 4), ("torch", [4, 5])]) +@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [1, 2])]) @pytest.mark.parametrize("accumulation_steps", [1, 3]) @magic_argv_env_context def test_trainer_torch_without_evaluator_accumulation_steps( @@ -181,7 +179,7 @@ def test_trainer_torch_without_evaluator_accumulation_steps( dist.destroy_process_group() -@pytest.mark.parametrize("driver,device", [("torch", [6, 7])]) +@pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) @pytest.mark.parametrize("output_from_new_proc", ["all", "ignore", "only_error", "test_log"]) @magic_argv_env_context def test_trainer_output_from_new_proc( @@ -244,7 +242,7 @@ def test_trainer_output_from_new_proc( synchronize_safe_rm(path) -@pytest.mark.parametrize("driver,device", [("torch", [4, 5])]) +@pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) @pytest.mark.parametrize("cur_rank", [0]) # 依次测试如果是当前进程出现错误,是否能够正确地 kill 掉其他进程; , 1, 2, 3 @magic_argv_env_context def test_trainer_on_exception( diff --git a/tests/core/log/test_logger.py b/tests/core/log/test_logger.py index e69de29b..da9b7b6b 100644 --- a/tests/core/log/test_logger.py +++ b/tests/core/log/test_logger.py @@ -0,0 +1,300 @@ +import os +import tempfile +import datetime +from pathlib import Path +import logging +import re + +from fastNLP.envs.env import FASTNLP_LAUNCH_TIME +from tests.helpers.utils import magic_argv_env_context +from fastNLP.core import synchronize_safe_rm + + +# 测试 TorchDDPDriver; +@magic_argv_env_context +def test_add_file_ddp_1(): + """ + 测试 path 是一个文件的地址,但是这个文件所在的文件夹存在; + + 多卡时根据时间创造文件名字有一个很大的 bug,就是不同的进程启动之间是有时差的,因此会导致他们各自输出到单独的 log 文件中; + """ + import torch + import torch.distributed as dist + + from fastNLP.core.log.logger import logger + from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver + from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 + + model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) + + driver = TorchDDPDriver( + model=model, + parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")], + output_from_new_proc="all" + ) + driver.setup() + msg = 'some test log msg' + + path = Path.cwd() + filepath = path.joinpath('log.txt') + handler = logger.add_file(filepath, mode="w") + logger.info(msg) + logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n") + + for h in logger.handlers: + if isinstance(h, logging.FileHandler): + h.flush() + dist.barrier() + with open(filepath, 'r') as f: + line = ''.join([l for l in f]) + assert msg in line + assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line + + pattern = re.compile(msg) + assert len(pattern.findall(line)) == 1 + + synchronize_safe_rm(filepath) + dist.barrier() + dist.destroy_process_group() + logger.removeHandler(handler) + + +@magic_argv_env_context +def test_add_file_ddp_2(): + """ + 测试 path 是一个文件的地址,但是这个文件所在的文件夹不存在; + """ + + import torch + import 
torch.distributed as dist + + from fastNLP.core.log.logger import logger + from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver + from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 + + model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) + + driver = TorchDDPDriver( + model=model, + parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")], + output_from_new_proc="all" + ) + driver.setup() + + msg = 'some test log msg' + + origin_path = Path.cwd() + try: + path = origin_path.joinpath("not_existed") + filepath = path.joinpath('log.txt') + handler = logger.add_file(filepath) + logger.info(msg) + logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n") + for h in logger.handlers: + if isinstance(h, logging.FileHandler): + h.flush() + dist.barrier() + with open(filepath, 'r') as f: + line = ''.join([l for l in f]) + + assert msg in line + assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line + pattern = re.compile(msg) + assert len(pattern.findall(line)) == 1 + finally: + synchronize_safe_rm(path) + logger.removeHandler(handler) + + dist.barrier() + dist.destroy_process_group() + + +@magic_argv_env_context +def test_add_file_ddp_3(): + """ + path = None; + + 多卡时根据时间创造文件名字有一个很大的 bug,就是不同的进程启动之间是有时差的,因此会导致他们各自输出到单独的 log 文件中; + """ + import torch + import torch.distributed as dist + + from fastNLP.core.log.logger import logger + from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver + from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 + + model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) + + driver = TorchDDPDriver( + model=model, + parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")], + output_from_new_proc="all" + ) + driver.setup() + msg = 'some test log msg' + + handler = logger.add_file() + logger.info(msg) + logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n") + + for h in logger.handlers: + if isinstance(h, logging.FileHandler): + h.flush() + dist.barrier() + file = Path.cwd().joinpath(os.environ.get(FASTNLP_LAUNCH_TIME)+".log") + with open(file, 'r') as f: + line = ''.join([l for l in f]) + + # print(f"\nrank: {driver.get_local_rank()} line, {line}\n") + assert msg in line + assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line + + pattern = re.compile(msg) + assert len(pattern.findall(line)) == 1 + + synchronize_safe_rm(file) + dist.barrier() + dist.destroy_process_group() + logger.removeHandler(handler) + +@magic_argv_env_context +def test_add_file_ddp_4(): + """ + 测试 path 是文件夹; + """ + + import torch + import torch.distributed as dist + + from fastNLP.core.log.logger import logger + from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver + from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 + + model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) + + driver = TorchDDPDriver( + model=model, + parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")], + output_from_new_proc="all" + ) + driver.setup() + msg = 'some test log msg' + + path = Path.cwd().joinpath("not_existed") + try: + handler = logger.add_file(path) + logger.info(msg) + logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n") + + for h in logger.handlers: + if isinstance(h, logging.FileHandler): + h.flush() + dist.barrier() + + file = path.joinpath(os.environ.get(FASTNLP_LAUNCH_TIME) + 
".log") + with open(file, 'r') as f: + line = ''.join([l for l in f]) + assert msg in line + assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line + pattern = re.compile(msg) + assert len(pattern.findall(line)) == 1 + finally: + synchronize_safe_rm(path) + logger.removeHandler(handler) + + dist.barrier() + dist.destroy_process_group() + + +class TestLogger: + msg = 'some test log msg' + + def test_add_file_1(self): + """ + 测试 path 是一个文件的地址,但是这个文件所在的文件夹存在; + """ + from fastNLP.core.log.logger import logger + + path = Path(tempfile.mkdtemp()) + try: + filepath = path.joinpath('log.txt') + handler = logger.add_file(filepath) + logger.info(self.msg) + with open(filepath, 'r') as f: + line = ''.join([l for l in f]) + assert self.msg in line + finally: + synchronize_safe_rm(path) + logger.removeHandler(handler) + + def test_add_file_2(self): + """ + 测试 path 是一个文件的地址,但是这个文件所在的文件夹不存在; + """ + from fastNLP.core.log.logger import logger + + origin_path = Path(tempfile.mkdtemp()) + + try: + path = origin_path.joinpath("not_existed") + path = path.joinpath('log.txt') + handler = logger.add_file(path) + logger.info(self.msg) + with open(path, 'r') as f: + line = ''.join([l for l in f]) + assert self.msg in line + finally: + synchronize_safe_rm(origin_path) + logger.removeHandler(handler) + + def test_add_file_3(self): + """ + 测试 path 是 None; + """ + from fastNLP.core.log.logger import logger + + handler = logger.add_file() + logger.info(self.msg) + + path = Path.cwd() + cur_datetime = str(datetime.datetime.now().strftime('%Y-%m-%d')) + for file in path.iterdir(): + if file.name.startswith(cur_datetime): + with open(file, 'r') as f: + line = ''.join([l for l in f]) + assert self.msg in line + file.unlink() + logger.removeHandler(handler) + + def test_add_file_4(self): + """ + 测试 path 是文件夹; + """ + from fastNLP.core.log.logger import logger + + path = Path(tempfile.mkdtemp()) + try: + handler = logger.add_file(path) + logger.info(self.msg) + + cur_datetime = str(datetime.datetime.now().strftime('%Y-%m-%d')) + for file in path.iterdir(): + if file.name.startswith(cur_datetime): + with open(file, 'r') as f: + line = ''.join([l for l in f]) + assert self.msg in line + finally: + synchronize_safe_rm(path) + logger.removeHandler(handler) + + def test_stdout(self, capsys): + from fastNLP.core.log.logger import logger + + handler = logger.set_stdout(stdout="raw") + logger.info(self.msg) + logger.debug('aabbc') + captured = capsys.readouterr() + assert "some test log msg\n" == captured.out + + logger.removeHandler(handler) + diff --git a/tests/core/samplers/test_sampler.py b/tests/core/samplers/test_sampler.py index 61e28dac..63d8e860 100644 --- a/tests/core/samplers/test_sampler.py +++ b/tests/core/samplers/test_sampler.py @@ -10,13 +10,6 @@ from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler from tests.helpers.datasets.torch_data import TorchNormalDataset - - - - - - - class SamplerTest(unittest.TestCase): def test_sequentialsampler(self): From a376eea776eed523ff0a9118a2cc97d0a3134db8 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Sun, 10 Apr 2022 12:56:49 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E9=87=8D=E6=96=B0=E4=BF=AE=E6=94=B9?= =?UTF-8?q?=E4=BA=86=E6=96=AD=E7=82=B9=E9=87=8D=E8=AE=AD=E7=9A=84=E9=80=BB?= =?UTF-8?q?=E8=BE=91=EF=BC=8C=E4=B8=BB=E8=A6=81=E4=BF=AE=E6=94=B9=E4=BA=86?= =?UTF-8?q?=20trainer.save/load=20=E5=92=8C=20driver.save=20=E5=92=8C=20lo?= =?UTF-8?q?ad=20=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit --- .../callbacks/load_best_model_callback.py | 2 +- .../core/drivers/torch_driver/torch_driver.py | 62 +++++++++++++++++-- .../test_checkpoint_callback_torch.py | 4 +- 3 files changed, 61 insertions(+), 7 deletions(-) diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py index bd6b8e66..b4ef4e62 100644 --- a/fastNLP/core/callbacks/load_best_model_callback.py +++ b/fastNLP/core/callbacks/load_best_model_callback.py @@ -81,7 +81,7 @@ class LoadBestModelCallback(Callback): real_monitor=self._real_monitor, res=results) if (monitor_value < self.monitor_value and self.larger_better is False) or \ - (monitor_value > self.monitor_value and self.larger_better): + (monitor_value > self.monitor_value and self.larger_better): self.monitor_value = monitor_value if self.real_save_folder: trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index 0e1a45e0..96d11761 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -30,6 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device from fastNLP.envs import rank_zero_call from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.core.log import logger +from fastNLP.core.samplers import ReproducibleBatchSampler class TorchDriver(Driver): @@ -178,8 +179,28 @@ class TorchDriver(Driver): model.load_state_dict(res.state_dict()) @rank_zero_call - def save(self, folder: Path, states: Dict, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): - # 1. 保存模型的状态; + def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): + # 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变 + # trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; + + # 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; + # 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `replace_sampler` 中将 dataloader 的 + # sampler 替换为 `ReproducibleIterator`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`; + dataloader_args = self.get_dataloader_args(dataloader) + if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): + sampler = dataloader_args.batch_sampler + elif dataloader_args.sampler: + sampler = dataloader_args.sampler + else: + raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") + + if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): + states['sampler_states'] = sampler.state_dict() + else: + raise RuntimeError( + 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') + + # 2. 保存模型的状态; if should_save_model: model = self.unwrap_model() if only_state_dict: @@ -191,7 +212,7 @@ class TorchDriver(Driver): torch.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME)) logger.debug("Save model") - # 2. 保存 optimizers 的状态; + # 3. 
保存 optimizers 的状态; optimizers_state_dict = {} for i in range(len(self.optimizers)): optimizer: torch.optim.Optimizer = self.optimizers[i] @@ -203,7 +224,7 @@ class TorchDriver(Driver): states["optimizers_state_dict"] = optimizers_state_dict torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) - def load(self, folder: Path, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: + def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) # 1. 加载 optimizers 的状态; @@ -224,6 +245,39 @@ class TorchDriver(Driver): model.load_state_dict(res.state_dict()) logger.debug("Load model.") + # 3. 恢复 sampler 的状态; + dataloader_args = self.get_dataloader_args(dataloader) + + sampler = dataloader_args.sampler + if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)): + # 说明这里需要使用 ReproduceSampler 来弄一下了 + if self.is_distributed(): + raise RuntimeError( + "It is not allowed to use single device checkpoint retraining before but ddp now.") + sampler = ReproducibleBatchSampler( + batch_sampler=sampler, + batch_size=dataloader_args.batch_size, + drop_last=dataloader_args.drop_last + ) + sampler.load_state_dict(states['sampler_states']) + + states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) + + # 4. 修改 trainer_state.batch_idx_in_epoch + # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; + if not isinstance(sampler, ReproducibleBatchSampler): + if dataloader_args.drop_last: + batch_idx_in_epoch = len( + sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size + else: + batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ + (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size + # sampler 是 batch_sampler; + else: + batch_idx_in_epoch = sampler.batch_idx_in_epoch + + states["batch_idx_in_epoch"] = batch_idx_in_epoch + return states def get_evaluate_context(self): diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py index 759135f0..f7cc6e5f 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -316,7 +316,7 @@ def test_model_checkpoint_callback_2( dist.destroy_process_group() -@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) +@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) @pytest.mark.parametrize("version", [0, 1]) @pytest.mark.parametrize("only_state_dict", [True, False]) @magic_argv_env_context @@ -466,7 +466,7 @@ def test_trainer_checkpoint_callback_1( # 通过自己编写 model_save_fn 和 model_load_fn 来测试 huggingface 的 transformers 的模型的保存和加载; -@pytest.mark.parametrize("driver,device", [("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) +@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) @pytest.mark.parametrize("version", [0, 1]) @magic_argv_env_context def test_trainer_checkpoint_callback_2(
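
The batch_idx_in_epoch recovery that the second patch adds to TorchDriver.load can be summarized with a minimal, hedged sketch. It relies only on what the hunk itself shows: a sampler exposing __len__ (total samples) and num_left_samples (samples not yet consumed in the interrupted epoch), plus the dataloader's batch_size and drop_last flag. The helper name batches_already_consumed below is illustrative and not part of fastNLP; the arithmetic mirrors the patch.

    # Sketch of the batch_idx_in_epoch arithmetic from TorchDriver.load, assuming a
    # sampler with __len__ and num_left_samples as in the patch above.
    def batches_already_consumed(total_samples: int, num_left_samples: int,
                                 batch_size: int, drop_last: bool) -> int:
        """Number of batches the interrupted epoch had already produced."""
        if drop_last:
            # Trailing partial batches are dropped, so floor-divide on both sides.
            return total_samples // batch_size - num_left_samples // batch_size
        # Otherwise the trailing partial batch counts, so use ceiling division.
        total_batches = (total_samples + batch_size - 1) // batch_size
        remaining_batches = (num_left_samples + batch_size - 1) // batch_size
        return total_batches - remaining_batches

    # Example: 100 samples, batch_size=8, 30 samples still unseen at checkpoint time.
    # drop_last=True  -> 12 - 3 = 9 batches already consumed.
    # drop_last=False -> 13 - 4 = 9 batches already consumed.
    assert batches_already_consumed(100, 30, 8, drop_last=True) == 9
    assert batches_already_consumed(100, 30, 8, drop_last=False) == 9

This count is only recoverable because the matching save() hunk stores sampler.state_dict() alongside the optimizer states, which is also why the patch raises a RuntimeError when the sampler has no state_dict method.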