From 655e48de99d7ab71d68c76a5c6659d2264d0bb0c Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Fri, 13 May 2022 11:19:31 +0000
Subject: [PATCH] 1. Change the multi-GPU `driver` argument in the torch
 tests. 2. Replace driver.save/driver.load in the tests with
 driver.save_checkpoint/driver.load_checkpoint. 3. Add an LSTM module.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../torch_driver/initialize_torch_driver.py   |  2 +-
 fastNLP/modules/torch/encoder/__init__.py     |  5 ++
 fastNLP/modules/torch/encoder/lstm.py         | 82 +++++++++++++++++++
 .../test_checkpoint_callback_torch.py         | 28 +++----
 .../test_load_best_model_callback_torch.py    |  2 +-
 .../callbacks/test_more_evaluate_callback.py  |  6 +-
 .../_test_distributed_launch_torch_1.py       |  2 +-
 .../_test_distributed_launch_torch_2.py       |  2 +-
 .../test_trainer_wo_evaluator_torch.py        |  2 +-
 .../core/drivers/paddle_driver/test_fleet.py  | 12 +--
 .../paddle_driver/test_single_device.py       | 12 +--
 tests/core/drivers/torch_driver/test_ddp.py   | 10 +--
 .../test_initialize_torch_driver.py           | 41 +--------
 .../torch_driver/test_single_device.py        |  8 +-
 tests/embeddings/torch/test_char_embedding.py |  5 +-
 15 files changed, 135 insertions(+), 84 deletions(-)
 create mode 100644 fastNLP/modules/torch/encoder/__init__.py
 create mode 100644 fastNLP/modules/torch/encoder/lstm.py

diff --git a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
index 723765d2..f8fe63d8 100644
--- a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
@@ -32,7 +32,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
         return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs)
 
     if driver not in {"torch", "fairscale"}:
-        raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'torch_ddp', 'fairscale'].")
+        raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale'].")
 
     _could_use_device_num = torch.cuda.device_count()
     if isinstance(device, str):
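Note on the change above: with the "torch_ddp" value gone, the driver string is always "torch" and the device argument alone decides between single-GPU and DDP execution. A minimal sketch of the resulting behavior (the model is a placeholder, and available CUDA devices are assumed):

    import torch.nn as nn
    from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver

    model = nn.Linear(64, 10)  # placeholder model for illustration
    # an integer device selects a single gpu -> TorchSingleDriver
    driver = initialize_torch_driver("torch", 1, model)
    # a list of devices requests distributed training -> TorchDDPDriver
    # (previously this required driver="torch_ddp")
    driver = initialize_torch_driver("torch", [0, 1], model)
    # driver="torch_ddp" itself now raises the ValueError shown above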
diff --git a/fastNLP/modules/torch/encoder/__init__.py b/fastNLP/modules/torch/encoder/__init__.py
new file mode 100644
index 00000000..d893305f
--- /dev/null
+++ b/fastNLP/modules/torch/encoder/__init__.py
@@ -0,0 +1,5 @@
+__all__ = [
+    "LSTM",
+]
+
+from .lstm import LSTM
\ No newline at end of file
diff --git a/fastNLP/modules/torch/encoder/lstm.py b/fastNLP/modules/torch/encoder/lstm.py
new file mode 100644
index 00000000..bd0d844d
--- /dev/null
+++ b/fastNLP/modules/torch/encoder/lstm.py
@@ -0,0 +1,82 @@
+r"""undocumented
+A lightweight wrapper around the PyTorch LSTM module.
+Sequence lengths can be passed in at forward time, and padding is then handled automatically.
+
+"""
+
+__all__ = [
+    "LSTM"
+]
+
+import torch
+import torch.nn as nn
+import torch.nn.utils.rnn as rnn
+
+
+class LSTM(nn.Module):
+    r"""
+    LSTM module, a lightweight wrapper around the PyTorch LSTM. When seq_len is provided, pack_padded_sequence is used
+    automatically; the forget-gate bias is initialized to 1 by default; and it handles the issues of using an LSTM under DataParallel.
+    """
+
+    def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
+                 bidirectional=False, bias=True):
+        r"""
+
+        :param input_size: feature dimension of the input `x`
+        :param hidden_size: feature dimension of the hidden state `h`. If bidirectional is True, the output dimension is hidden_size*2
+        :param num_layers: number of RNN layers. Default: 1
+        :param dropout: dropout probability between layers. Default: 0
+        :param batch_first: if ``True``, the input and output ``Tensor`` shapes are
+            (batch, seq, feature). Default: ``True``
+        :param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False``
+        :param bias: if ``False``, the layer does not use bias weights. Default: ``True``
+        """
+        super(LSTM, self).__init__()
+        self.batch_first = batch_first
+        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first,
+                            dropout=dropout, bidirectional=bidirectional)
+        self.init_param()
+
+    def init_param(self):
+        for name, param in self.named_parameters():
+            if 'bias' in name:
+                # based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871
+                param.data.fill_(0)
+                n = param.size(0)
+                start, end = n // 4, n // 2
+                param.data[start:end].fill_(1)
+            else:
+                nn.init.xavier_uniform_(param)
+
+    def forward(self, x, seq_len=None, h0=None, c0=None):
+        r"""
+        :param x: [batch, seq_len, input_size] input sequence
+        :param seq_len: [batch, ] sequence lengths; if ``None``, all inputs are treated as equally long. Default: ``None``
+        :param h0: [num_layers*num_direction, batch, hidden_size] initial hidden state; if ``None``, zeros are used. Default: ``None``
+        :param c0: [num_layers*num_direction, batch, hidden_size] initial cell state; if ``None``, zeros are used. Default: ``None``
+        :return (output, (ht, ct)): output: [batch, seq_len, hidden_size*num_direction] the output sequence,
+            and ht, ct: [num_layers*num_direction, batch, hidden_size] the hidden states at the last time step.
+        """
+        batch_size, max_len, _ = x.size()
+        if h0 is not None and c0 is not None:
+            hx = (h0, c0)
+        else:
+            hx = None
+        if seq_len is not None and not isinstance(x, rnn.PackedSequence):
+            sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
+            if self.batch_first:
+                x = x[sort_idx]
+            else:
+                x = x[:, sort_idx]
+            x = rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=self.batch_first)
+            output, hx = self.lstm(x, hx)  # -> [N,L,C]
+            output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len)
+            _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
+            if self.batch_first:
+                output = output[unsort_idx]
+            else:
+                output = output[:, unsort_idx]
+            hx = hx[0][:, unsort_idx], hx[1][:, unsort_idx]
+        else:
+            output, hx = self.lstm(x, hx)
+        return output, hx
\ No newline at end of file
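A minimal usage sketch of the new encoder (sizes are made up for illustration). When seq_len is given, the input is packed before the LSTM runs, so padded positions do not affect the returned hidden states; the [n//4, n//2) bias slice filled in init_param is the forget-gate block in PyTorch's [b_ii | b_if | b_ig | b_io] bias layout:

    import torch
    from fastNLP.modules.torch.encoder import LSTM

    lstm = LSTM(input_size=8, hidden_size=16, batch_first=True)
    x = torch.randn(4, 10, 8)              # [batch, seq_len, input_size], padded to length 10
    seq_len = torch.tensor([10, 7, 5, 2])  # true lengths; they need not be pre-sorted
    output, (ht, ct) = lstm(x, seq_len=seq_len)
    assert output.shape == (4, 10, 16)     # padded back to max_len by pad_packed_sequence
    assert ht.shape == (1, 4, 16)          # [num_layers*num_direction, batch, hidden_size]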
diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py
index 3105acba..5f7d553f 100644
--- a/tests/core/callbacks/test_checkpoint_callback_torch.py
+++ b/tests/core/callbacks/test_checkpoint_callback_torch.py
@@ -74,7 +74,7 @@ def model_and_optimizers(request):
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @pytest.mark.parametrize("version", [0, 1])
 @pytest.mark.parametrize("only_state_dict", [True, False])
 @magic_argv_env_context(timeout=100)
@@ -121,7 +121,7 @@ def test_model_checkpoint_callback_1(
 
             # check that the number of saved model files is correct;
             if version == 0:
-                if driver == "torch":
+                if not isinstance(device, list):
                     assert "model-epoch_10" in all_saved_model_paths
                     assert "model-epoch_4-batch_123" in all_saved_model_paths
@@ -144,7 +144,7 @@
             pattern = re.compile("model-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")
 
-            if driver == "torch":
+            if not isinstance(device, list):
                 assert "model-epoch_9" in all_saved_model_paths
                 assert "model-last" in all_saved_model_paths
                 aLL_topk_folders = []
@@ -206,7 +206,7 @@
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @pytest.mark.parametrize("only_state_dict", [True])
 @magic_argv_env_context(timeout=100)
 def test_model_checkpoint_callback_2(
@@ -259,7 +259,7 @@
 
         # check that the number of saved model files is correct;
         all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
-        if driver == "torch":
+        if not isinstance(device, list):
             assert "model-epoch_4-batch_100-exception_NotImplementedError" in all_saved_model_paths
             exception_model_path = all_saved_model_paths["model-epoch_4-batch_100-exception_NotImplementedError"]
         # the file name differs under ddp, because with the same data ddp finishes in fewer steps;
@@ -299,7 +299,7 @@
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @pytest.mark.parametrize("version", [0, 1])
 @pytest.mark.parametrize("only_state_dict", [True, False])
 @magic_argv_env_context(timeout=100)
@@ -347,7 +347,7 @@
 
         # check that the number of saved model files is correct;
         if version == 0:
-            if driver == "torch":
+            if not isinstance(device, list):
                 assert "trainer-epoch_7" in all_saved_model_paths
                 assert "trainer-epoch_4-batch_123" in all_saved_model_paths
@@ -371,7 +371,7 @@
         pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")
 
         # all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
-        if driver == "torch":
+        if not isinstance(device, list):
             assert "trainer-last" in all_saved_model_paths
             aLL_topk_folders = []
             for each_folder_name in all_saved_model_paths:
@@ -417,7 +417,7 @@
             n_epochs=13,
             output_from_new_proc="all"
         )
-        trainer.load(folder, only_state_dict=only_state_dict)
+        trainer.load_checkpoint(folder, only_state_dict=only_state_dict)
 
         trainer.run()
         trainer.driver.barrier()
@@ -489,7 +489,7 @@ def test_load_state(model_and_optimizers):
         callbacks=callbacks,
         output_from_new_proc="all"
     )
-    trainer.load(folder=epoch_2_path)
+    trainer.load_checkpoint(folder=epoch_2_path)
     with Capturing() as output:
         trainer.run(num_eval_sanity_batch=0, num_train_batch_per_epoch=2)
 
@@ -503,7 +503,7 @@ def test_load_state(model_and_optimizers):
 
 @pytest.mark.torch
 # test saving and loading of huggingface transformers models via hand-written model_save_fn and model_load_fn;
-@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @pytest.mark.parametrize("version", [0, 1])
 @magic_argv_env_context
 @pytest.mark.skip("Skip transformers test for now.")
@@ -675,7 +675,7 @@ def test_trainer_checkpoint_callback_2(
 
         # check that the number of saved model files is correct;
         if version == 0:
-            if driver == "torch":
+            if not isinstance(device, list):
                 assert "trainer-epoch_1-batch_200" in all_saved_model_paths
 
                 epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_200"]
@@ -695,7 +695,7 @@
         pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")
 
         # all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
-        if driver == "torch":
+        if not isinstance(device, list):
             assert "trainer-last" in all_saved_model_paths
             aLL_topk_folders = []
             for each_folder_name in all_saved_model_paths:
@@ -740,7 +740,7 @@
             output_mapping=bert_output_mapping,
             metrics={"acc": acc},
         )
-        trainer.load(folder, model_load_fn=model_load_fn)
+        trainer.load_checkpoint(folder, model_load_fn=model_load_fn)
 
         trainer.run()
         trainer.driver.barrier()
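The trainer-side rename in the hunks above is mechanical: resuming from a checkpoint folder now goes through Trainer.load_checkpoint rather than Trainer.load, with the arguments unchanged. A rough sketch of a resumed run; model, optimizers, train_dataloader and folder are assumed to exist, mirroring the interrupted run as in these tests:

    trainer = Trainer(
        model=model,
        driver="torch",
        device="cpu",
        optimizers=optimizers,
        train_dataloader=train_dataloader,
    )
    trainer.load_checkpoint(folder, only_state_dict=True)  # was: trainer.load(folder, ...)
    trainer.run()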
("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) @pytest.mark.parametrize("save_folder", ['save_models', None]) @pytest.mark.parametrize("only_state_dict", [True, False]) @magic_argv_env_context diff --git a/tests/core/callbacks/test_more_evaluate_callback.py b/tests/core/callbacks/test_more_evaluate_callback.py index 9c32c20b..925be172 100644 --- a/tests/core/callbacks/test_more_evaluate_callback.py +++ b/tests/core/callbacks/test_more_evaluate_callback.py @@ -98,7 +98,7 @@ def model_and_optimizers(request): @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) +@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) @pytest.mark.parametrize("version", [0, 1]) @pytest.mark.parametrize("only_state_dict", [True, False]) @magic_argv_env_context @@ -183,7 +183,7 @@ def test_model_more_evaluate_callback_1( @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) +@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) @pytest.mark.parametrize("version", [0, 1]) @pytest.mark.parametrize("only_state_dict", [True, False]) @magic_argv_env_context @@ -256,7 +256,7 @@ def test_trainer_checkpoint_callback_1( evaluate_fn='train_step' ) folder = path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).joinpath(folder) - trainer.load(folder, only_state_dict=only_state_dict) + trainer.load_checkpoint(folder, only_state_dict=only_state_dict) trainer.run() trainer.driver.barrier() diff --git a/tests/core/controllers/_test_distributed_launch_torch_1.py b/tests/core/controllers/_test_distributed_launch_torch_1.py index 60f5e36f..0f607423 100644 --- a/tests/core/controllers/_test_distributed_launch_torch_1.py +++ b/tests/core/controllers/_test_distributed_launch_torch_1.py @@ -85,7 +85,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps( ): trainer = Trainer( model=model, - driver="torch_ddp", + driver="torch", device=None, optimizers=optimizers, train_dataloader=train_dataloader, diff --git a/tests/core/controllers/_test_distributed_launch_torch_2.py b/tests/core/controllers/_test_distributed_launch_torch_2.py index 37b22590..650f2782 100644 --- a/tests/core/controllers/_test_distributed_launch_torch_2.py +++ b/tests/core/controllers/_test_distributed_launch_torch_2.py @@ -73,7 +73,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps( ): trainer = Trainer( model=model, - driver="torch_ddp", + driver="torch", device=None, optimizers=optimizers, train_dataloader=train_dataloader, diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py index e3d90e9b..5b794459 100644 --- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py @@ -318,7 +318,7 @@ def test_torch_distributed_launch_2(version): @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch_ddp", [0, 1])]) +@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch", [0, 1])]) @magic_argv_env_context def test_torch_wo_auto_param_call( driver, diff --git a/tests/core/drivers/paddle_driver/test_fleet.py 
diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py
index ef22ba80..d3bffb9f 100644
--- a/tests/core/drivers/paddle_driver/test_fleet.py
+++ b/tests/core/drivers/paddle_driver/test_fleet.py
@@ -626,9 +626,9 @@ class TestSaveLoad:
             sampler_states = dataloader.batch_sampler.state_dict()
             save_states = {"num_consumed_batches": num_consumed_batches}
             if only_state_dict:
-                self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+                self.driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
             else:
-                self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
+                self.driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
             # load
             # change batch_size
             dataloader = DataLoader(
@@ -644,7 +644,7 @@ class TestSaveLoad:
                 rank=self.driver2.global_rank,
                 pad=True
             )
-            load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+            load_states = self.driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
             replaced_loader = load_states.pop("dataloader")
             # 1. check the optimizer state
             # TODO the optimizer state_dict is always empty
@@ -736,9 +736,9 @@ class TestSaveLoad:
             sampler_states = dataloader.batch_sampler.sampler.state_dict()
             save_states = {"num_consumed_batches": num_consumed_batches}
             if only_state_dict:
-                self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+                self.driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
             else:
-                self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
+                self.driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
             # load
             # change batch_size
             batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
@@ -752,7 +752,7 @@ class TestSaveLoad:
                 self.dataset,
                 batch_sampler=batch_sampler
             )
-            load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+            load_states = self.driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
             replaced_loader = load_states.pop("dataloader")
             # 1. check the optimizer state
diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py
index ffcb35e7..e7d6707a 100644
--- a/tests/core/drivers/paddle_driver/test_single_device.py
+++ b/tests/core/drivers/paddle_driver/test_single_device.py
@@ -615,16 +615,16 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
     sampler_states = dataloader.batch_sampler.state_dict()
     save_states = {"num_consumed_batches": num_consumed_batches}
     if only_state_dict:
-        driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+        driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
     else:
-        driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
+        driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
     # load
     # change batch_size
     dataloader = DataLoader(
         dataset=dataset,
         batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False)
     )
-    load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+    load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
     replaced_loader = load_states.pop("dataloader")
     # 1. check the optimizer state
     # TODO the optimizer state_dict is always empty
@@ -697,9 +697,9 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16):
     sampler_states = dataloader.batch_sampler.sampler.state_dict()
     save_states = {"num_consumed_batches": num_consumed_batches}
     if only_state_dict:
-        driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+        driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
     else:
-        driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
+        driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
 
     # load
     # change batch_size
@@ -709,7 +709,7 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16):
         dataset,
         batch_sampler=batch_sampler
     )
-    load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+    load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
     replaced_loader = load_states.pop("dataloader")
     # 1. check the optimizer state
diff --git a/tests/core/drivers/torch_driver/test_ddp.py b/tests/core/drivers/torch_driver/test_ddp.py
index 89e0a7ae..cb7ed68c 100644
--- a/tests/core/drivers/torch_driver/test_ddp.py
+++ b/tests/core/drivers/torch_driver/test_ddp.py
@@ -648,7 +648,7 @@ class TestSaveLoad:
             # save the state
             sampler_states = dataloader.batch_sampler.state_dict()
             save_states = {"num_consumed_batches": num_consumed_batches}
-            driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+            driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
             # load
             # change batch_size
             dataloader = dataloader_with_bucketedbatchsampler(
@@ -663,7 +663,7 @@ class TestSaveLoad:
                 rank=driver2.global_rank,
                 pad=True
             )
-            load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+            load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
             replaced_loader = load_states.pop("dataloader")
             # 1. check the optimizer state
             # TODO the optimizer state_dict is always empty
@@ -754,9 +754,9 @@ class TestSaveLoad:
             sampler_states = dataloader.batch_sampler.sampler.state_dict()
             save_states = {"num_consumed_batches": num_consumed_batches}
             if only_state_dict:
-                driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+                driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
             else:
-                driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))])
+                driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))])
             # load
             # change batch_size
             dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False)
@@ -765,7 +765,7 @@ class TestSaveLoad:
                 rank=driver2.global_rank,
                 pad=True
             )
-            load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+            load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
             replaced_loader = load_states.pop("dataloader")
             # 1. check the optimizer state
diff --git a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py
index 8ec70de1..dc89ad0d 100644
--- a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py
+++ b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py
@@ -37,28 +37,6 @@ def test_get_single_device(driver, device):
     driver = initialize_torch_driver(driver, device, model)
     assert isinstance(driver, TorchSingleDriver)
 
-
-@pytest.mark.torch
-@pytest.mark.parametrize(
-    "device",
-    [0, 1]
-)
-@pytest.mark.parametrize(
-    "driver",
-    ["torch_ddp"]
-)
-@magic_argv_env_context
-def test_get_ddp_2(driver, device):
-    """
-    Test initializing multi-GPU ddp when a single gpu is passed in
-    """
-
-    model = TorchNormalModel_Classification_1(64, 10)
-    driver = initialize_torch_driver(driver, device, model)
-
-    assert isinstance(driver, TorchDDPDriver)
-
-
 @pytest.mark.torch
 @pytest.mark.parametrize(
     "device",
@@ -66,7 +44,7 @@
 )
 @pytest.mark.parametrize(
     "driver",
-    ["torch", "torch_ddp"]
+    ["torch"]
 )
 @magic_argv_env_context
 def test_get_ddp(driver, device):
@@ -79,21 +57,6 @@ def test_get_ddp(driver, device):
 
     assert isinstance(driver, TorchDDPDriver)
 
-
-@pytest.mark.torch
-@pytest.mark.parametrize(
-    ("driver", "device"),
-    [("torch_ddp", "cpu")]
-)
-def test_get_ddp_cpu(driver, device):
-    """
-    Test attempting to initialize distributed training on the cpu
-    """
-    model = TorchNormalModel_Classification_1(64, 10)
-    with pytest.raises(ValueError):
-        driver = initialize_torch_driver(driver, device, model)
-
-
 @pytest.mark.torch
 @pytest.mark.parametrize(
     "device",
@@ -101,7 +64,7 @@
 )
 @pytest.mark.parametrize(
     "driver",
-    ["torch", "torch_ddp"]
+    ["torch"]
 )
 def test_device_out_of_range(driver, device):
     """
diff --git a/tests/core/drivers/torch_driver/test_single_device.py b/tests/core/drivers/torch_driver/test_single_device.py
index 086f4251..1fbc9d82 100644
--- a/tests/core/drivers/torch_driver/test_single_device.py
+++ b/tests/core/drivers/torch_driver/test_single_device.py
@@ -595,12 +595,12 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
 
     sampler_states = dataloader.batch_sampler.state_dict()
     save_states = {"num_consumed_batches": num_consumed_batches}
-    driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+    driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
 
     # load
     # change batch_size
     dataloader = dataloader_with_randombatchsampler(dataset, 2, True, False)
-    load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+    load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
     replaced_loader = load_states.pop("dataloader")
 
     # 1. check the optimizer state
     # TODO the optimizer state_dict is always empty
@@ -664,12 +664,12 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16):
 
     sampler_states = dataloader.batch_sampler.sampler.state_dict()
     save_states = {"num_consumed_batches": num_consumed_batches}
-    driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+    driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
 
     # load
     # change batch_size
     dataloader = dataloader_with_randomsampler(dataset, 2, True, False)
-    load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+    load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
     replaced_loader = load_states.pop("dataloader")
 
     # 1. check the optimizer state
diff --git a/tests/embeddings/torch/test_char_embedding.py b/tests/embeddings/torch/test_char_embedding.py
index 81ed757a..8decce3f 100644
--- a/tests/embeddings/torch/test_char_embedding.py
+++ b/tests/embeddings/torch/test_char_embedding.py
@@ -7,8 +7,9 @@ from fastNLP import Vocabulary, DataSet, Instance
 from fastNLP.embeddings.torch.char_embedding import LSTMCharEmbedding, CNNCharEmbedding
 
 
+@pytest.mark.torch
 class TestCharEmbed:
-    @pytest.mark.test
+    # @pytest.mark.test
     def test_case_1(self):
         ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['Jack'])])
         vocab = Vocabulary().from_dataset(ds, field_name='words')
@@ -18,7 +19,7 @@ class TestCharEmbed:
         y = embed(x)
         assert tuple(y.size()) == (2, 3, 3)
 
-    @pytest.mark.test
+    # @pytest.mark.test
     def test_case_2(self):
         ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['Jack'])])
         vocab = Vocabulary().from_dataset(ds, field_name='words')
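A closing note on the marker change just above: the per-test @pytest.mark.test markers are commented out in favor of a class-level @pytest.mark.torch, so these cases are now collected together with the rest of the torch-marked suite. One way to select them by marker (a sketch; the file path is taken from the diff):

    import pytest

    # run only the torch-marked tests in this file
    pytest.main(["-m", "torch", "tests/embeddings/torch/test_char_embedding.py"])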