@@ -32,7 +32,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
         return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs)
     if driver not in {"torch", "fairscale"}:
-        raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'torch_ddp', 'fairscale'].")
+        raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale'].")
     _could_use_device_num = torch.cuda.device_count()
     if isinstance(device, str):
@@ -43,6 +43,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
                 raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.")
             device = [torch.device(f"cuda:{w}") for w in range(_could_use_device_num)]
         elif device >= _could_use_device_num:
+            print(device, _could_use_device_num)
            raise ValueError("The gpu device that parameter `device` specifies does not exist.")
        else:
            device = torch.device(f"cuda:{device}")
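Functionally, this hunk folds the old `torch_ddp` driver name into `torch`: whether training is distributed is now inferred from `device` alone. A minimal sketch of the resulting behavior; the import path and the stand-in model are assumptions, not taken from this diff:

```python
import torch.nn as nn

# Assumed import path; the tests below pull `initialize_torch_driver`
# from fastNLP's torch driver package.
from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver

model = nn.Linear(64, 10)  # placeholder for the test models

# An int (or a "cuda:x" string) selects a single-device driver ...
single_driver = initialize_torch_driver("torch", 1, model)

# ... while a list of GPU ids selects the DDP driver; the former
# driver="torch_ddp" spelling now hits the ValueError above.
ddp_driver = initialize_torch_driver("torch", [0, 1], model)
```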
| @@ -0,0 +1,5 @@ | |||||
| __all__ = [ | |||||
| "LSTM", | |||||
| ] | |||||
| from .lstm import LSTM | |||||
| @@ -0,0 +1,82 @@ | |||||
| r"""undocumented | |||||
| 轻量封装的 Pytorch LSTM 模块. | |||||
| 可在 forward 时传入序列的长度, 自动对padding做合适的处理. | |||||
| """ | |||||
| __all__ = [ | |||||
| "LSTM" | |||||
| ] | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.utils.rnn as rnn | |||||
| class LSTM(nn.Module): | |||||
| r""" | |||||
| LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化 | |||||
| 为1; 且可以应对DataParallel中LSTM的使用问题。 | |||||
| """ | |||||
| def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, | |||||
| bidirectional=False, bias=True): | |||||
| r""" | |||||
| :param input_size: 输入 `x` 的特征维度 | |||||
| :param hidden_size: 隐状态 `h` 的特征维度. 如果bidirectional为True,则输出的维度会是hidde_size*2 | |||||
| :param num_layers: rnn的层数. Default: 1 | |||||
| :param dropout: 层间dropout概率. Default: 0 | |||||
| :param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False`` | |||||
| :param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 | |||||
| :(batch, seq, feature). Default: ``False`` | |||||
| :param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` | |||||
| """ | |||||
| super(LSTM, self).__init__() | |||||
| self.batch_first = batch_first | |||||
| self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, | |||||
| dropout=dropout, bidirectional=bidirectional) | |||||
| self.init_param() | |||||
| def init_param(self): | |||||
| for name, param in self.named_parameters(): | |||||
| if 'bias' in name: | |||||
| # based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871 | |||||
| param.data.fill_(0) | |||||
| n = param.size(0) | |||||
| start, end = n // 4, n // 2 | |||||
| param.data[start:end].fill_(1) | |||||
| else: | |||||
| nn.init.xavier_uniform_(param) | |||||
| def forward(self, x, seq_len=None, h0=None, c0=None): | |||||
| r""" | |||||
| :param x: [batch, seq_len, input_size] 输入序列 | |||||
| :param seq_len: [batch, ] 序列长度, 若为 ``None``, 所有输入看做一样长. Default: ``None`` | |||||
| :param h0: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全0向量. Default: ``None`` | |||||
| :param c0: [batch, hidden_size] 初始Cell状态, 若为 ``None`` , 设为全0向量. Default: ``None`` | |||||
| :return (output, (ht, ct)): output: [batch, seq_len, hidden_size*num_direction] 输出序列 | |||||
| 和 ht,ct: [num_layers*num_direction, batch, hidden_size] 最后时刻隐状态. | |||||
| """ | |||||
| batch_size, max_len, _ = x.size() | |||||
| if h0 is not None and c0 is not None: | |||||
| hx = (h0, c0) | |||||
| else: | |||||
| hx = None | |||||
| if seq_len is not None and not isinstance(x, rnn.PackedSequence): | |||||
| sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) | |||||
| if self.batch_first: | |||||
| x = x[sort_idx] | |||||
| else: | |||||
| x = x[:, sort_idx] | |||||
| x = rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=self.batch_first) | |||||
| output, hx = self.lstm(x, hx) # -> [N,L,C] | |||||
| output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len) | |||||
| _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | |||||
| if self.batch_first: | |||||
| output = output[unsort_idx] | |||||
| else: | |||||
| output = output[:, unsort_idx] | |||||
| hx = hx[0][:, unsort_idx], hx[1][:, unsort_idx] | |||||
| else: | |||||
| output, hx = self.lstm(x, hx) | |||||
| return output, hx | |||||
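Since the new module handles packing, sorting, and unsorting internally, calling it needs only the padded batch and a length tensor. A usage sketch, assuming `LSTM` is re-exported from `fastNLP.modules.torch` (the diff does not show the file's location); shapes are illustrative:

```python
import torch
from fastNLP.modules.torch import LSTM  # assumed export path

lstm = LSTM(input_size=8, hidden_size=16)   # batch_first=True by default
x = torch.randn(4, 5, 8)                    # 4 padded sequences, max length 5
seq_len = torch.tensor([5, 4, 2, 2])        # true lengths; the rest is padding
output, (ht, ct) = lstm(x, seq_len=seq_len)
assert tuple(output.shape) == (4, 5, 16)    # padded back to max_len=5
assert tuple(ht.shape) == (1, 4, 16)        # num_layers * num_directions = 1
```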
@@ -74,7 +74,7 @@ def model_and_optimizers(request):
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", [0, 1])])  # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @pytest.mark.parametrize("version", [0, 1])
 @pytest.mark.parametrize("only_state_dict", [True, False])
 @magic_argv_env_context(timeout=100)
@@ -121,7 +121,7 @@ def test_model_checkpoint_callback_1(
     # check that the number of saved model files is correct;
     if version == 0:
-        if driver == "torch":
+        if not isinstance(device, list):
             assert "model-epoch_10" in all_saved_model_paths
             assert "model-epoch_4-batch_123" in all_saved_model_paths
@@ -144,7 +144,7 @@ def test_model_checkpoint_callback_1(
     pattern = re.compile("model-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")
-    if driver == "torch":
+    if not isinstance(device, list):
         assert "model-epoch_9" in all_saved_model_paths
         assert "model-last" in all_saved_model_paths
         aLL_topk_folders = []
@@ -206,7 +206,7 @@ def test_model_checkpoint_callback_1(
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)])  # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @pytest.mark.parametrize("only_state_dict", [True])
 @magic_argv_env_context(timeout=100)
 def test_model_checkpoint_callback_2(
@@ -259,7 +259,7 @@ def test_model_checkpoint_callback_2(
     # check that the number of saved model files is correct;
     all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
-    if driver == "torch":
+    if not isinstance(device, list):
         assert "model-epoch_4-batch_100-exception_NotImplementedError" in all_saved_model_paths
         exception_model_path = all_saved_model_paths["model-epoch_4-batch_100-exception_NotImplementedError"]
     # file names differ under ddp, because ddp finishes the same data in fewer steps;
@@ -299,7 +299,7 @@ def test_model_checkpoint_callback_2(
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 0)])  # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @pytest.mark.parametrize("version", [0, 1])
 @pytest.mark.parametrize("only_state_dict", [True, False])
 @magic_argv_env_context(timeout=100)
@@ -347,7 +347,7 @@ def test_trainer_checkpoint_callback_1(
     # check that the number of saved model files is correct;
     if version == 0:
-        if driver == "torch":
+        if not isinstance(device, list):
             assert "trainer-epoch_7" in all_saved_model_paths
             assert "trainer-epoch_4-batch_123" in all_saved_model_paths
@@ -371,7 +371,7 @@ def test_trainer_checkpoint_callback_1(
     pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")
     # all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
-    if driver == "torch":
+    if not isinstance(device, list):
         assert "trainer-last" in all_saved_model_paths
         aLL_topk_folders = []
         for each_folder_name in all_saved_model_paths:
@@ -417,7 +417,7 @@ def test_trainer_checkpoint_callback_1(
         n_epochs=13,
         output_from_new_proc="all"
     )
-    trainer.load(folder, only_state_dict=only_state_dict)
+    trainer.load_checkpoint(folder, only_state_dict=only_state_dict)
     trainer.run()
     trainer.driver.barrier()
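The rename from `Trainer.load` to `Trainer.load_checkpoint` recurs through the rest of this diff with unchanged arguments. Consolidated, the call sites follow this pattern (a sketch assembled from the hunks above and below, not a new API):

```python
# Resume trainer + model state from a checkpoint folder written during training.
trainer.load_checkpoint(folder, only_state_dict=only_state_dict)

# Variant with a user-supplied weight-loading hook (see the transformers test below).
trainer.load_checkpoint(folder, model_load_fn=model_load_fn)

trainer.run()
```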
@@ -489,7 +489,7 @@ def test_load_state(model_and_optimizers):
         callbacks=callbacks,
         output_from_new_proc="all"
     )
-    trainer.load(folder=epoch_2_path)
+    trainer.load_checkpoint(folder=epoch_2_path)
     with Capturing() as output:
         trainer.run(num_eval_sanity_batch=0, num_train_batch_per_epoch=2)
@@ -503,7 +503,7 @@ def test_load_state(model_and_optimizers):
 @pytest.mark.torch
 # test saving and loading a huggingface transformers model via custom model_save_fn and model_load_fn;
-@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", [6, 7]), ("torch", 7)])  # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @pytest.mark.parametrize("version", [0, 1])
 @magic_argv_env_context
 @pytest.mark.skip("Skip transformers test for now.")
@@ -675,7 +675,7 @@ def test_trainer_checkpoint_callback_2(
     # check that the number of saved model files is correct;
     if version == 0:
-        if driver == "torch":
+        if not isinstance(device, list):
             assert "trainer-epoch_1-batch_200" in all_saved_model_paths
             epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_200"]
@@ -695,7 +695,7 @@ def test_trainer_checkpoint_callback_2(
     pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")
     # all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
-    if driver == "torch":
+    if not isinstance(device, list):
         assert "trainer-last" in all_saved_model_paths
         aLL_topk_folders = []
         for each_folder_name in all_saved_model_paths:
@@ -740,7 +740,7 @@ def test_trainer_checkpoint_callback_2(
         output_mapping=bert_output_mapping,
         metrics={"acc": acc},
     )
-    trainer.load(folder, model_load_fn=model_load_fn)
+    trainer.load_checkpoint(folder, model_load_fn=model_load_fn)
     trainer.run()
     trainer.driver.barrier()
@@ -72,7 +72,7 @@ def model_and_optimizers(request):
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch_ddp", [4, 5]), ("torch", 1), ("torch", "cpu")])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", [4, 5]), ("torch", 1), ("torch", "cpu")])  # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @pytest.mark.parametrize("save_folder", ['save_models', None])
 @pytest.mark.parametrize("only_state_dict", [True, False])
 @magic_argv_env_context
@@ -98,7 +98,7 @@ def model_and_optimizers(request):
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)])  # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @pytest.mark.parametrize("version", [0, 1])
 @pytest.mark.parametrize("only_state_dict", [True, False])
 @magic_argv_env_context
@@ -183,7 +183,7 @@ def test_model_more_evaluate_callback_1(
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 0)])  # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @pytest.mark.parametrize("version", [0, 1])
 @pytest.mark.parametrize("only_state_dict", [True, False])
 @magic_argv_env_context
@@ -256,7 +256,7 @@ def test_trainer_checkpoint_callback_1(
         evaluate_fn='train_step'
     )
     folder = path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).joinpath(folder)
-    trainer.load(folder, only_state_dict=only_state_dict)
+    trainer.load_checkpoint(folder, only_state_dict=only_state_dict)
     trainer.run()
     trainer.driver.barrier()
@@ -85,7 +85,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps(
 ):
     trainer = Trainer(
         model=model,
-        driver="torch_ddp",
+        driver="torch",
         device=None,
         optimizers=optimizers,
         train_dataloader=train_dataloader,
@@ -73,7 +73,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps(
 ):
     trainer = Trainer(
         model=model,
-        driver="torch_ddp",
+        driver="torch",
         device=None,
         optimizers=optimizers,
         train_dataloader=train_dataloader,
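Both `_test_trainer_torch_with_evaluator_fp16_accumulation_steps` hunks keep `device=None` while switching the driver name, which only works because these helpers run under an external distributed launcher. A hedged sketch of why, based on the `LOCAL_RANK` early-return branch visible in the first hunk of this diff (the exact guarding condition is not shown and is assumed):

```python
# Assumed launch:  torchrun --nproc_per_node=2 test_script.py
import os
import torch

if "LOCAL_RANK" in os.environ:  # set per process by torchrun-style launchers
    # Mirrors the early-return branch above: each process binds to its own GPU,
    # so driver="torch" with device=None still resolves to a TorchDDPDriver.
    device = torch.device(f"cuda:{os.environ['LOCAL_RANK']}")
```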
@@ -318,7 +318,7 @@ def test_torch_distributed_launch_2(version):
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch_ddp", [0, 1])])
+@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch", [0, 1])])
 @magic_argv_env_context
 def test_torch_wo_auto_param_call(
     driver,
@@ -626,9 +626,9 @@ class TestSaveLoad:
             sampler_states = dataloader.batch_sampler.state_dict()
             save_states = {"num_consumed_batches": num_consumed_batches}
             if only_state_dict:
-                self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+                self.driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
             else:
-                self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
+                self.driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
             # load
             # change batch_size
             dataloader = DataLoader(
@@ -644,7 +644,7 @@ class TestSaveLoad:
                 rank=self.driver2.global_rank,
                 pad=True
             )
-            load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+            load_states = self.driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
             replaced_loader = load_states.pop("dataloader")
             # 1. check the optimizer state
             # TODO the optimizer state_dict is always empty
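The driver-level API is renamed the same way: `Driver.save`/`Driver.load` become `save_checkpoint`/`load_checkpoint`, with signatures unchanged at every call site. The round trip the save/load tests exercise, sketched with the names from the hunks (a pattern summary, with `driver1`, `driver2`, `path`, and `dataloader` as in the tests):

```python
from pathlib import Path

# Persist trainer state plus sampler state so training can resume mid-epoch.
save_states = {"num_consumed_batches": num_consumed_batches}
driver1.save_checkpoint(Path(path), save_states, dataloader,
                        only_state_dict, should_save_model=True)

# Restore into a second driver; the (possibly rebuilt) dataloader comes back
# inside the returned state dict.
load_states = driver2.load_checkpoint(Path(path), dataloader,
                                      only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")
```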
@@ -736,9 +736,9 @@ class TestSaveLoad:
             sampler_states = dataloader.batch_sampler.sampler.state_dict()
             save_states = {"num_consumed_batches": num_consumed_batches}
             if only_state_dict:
-                self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+                self.driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
             else:
-                self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
+                self.driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
             # load
             # change batch_size
             batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
@@ -752,7 +752,7 @@ class TestSaveLoad:
                 self.dataset,
                 batch_sampler=batch_sampler
             )
-            load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+            load_states = self.driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
             replaced_loader = load_states.pop("dataloader")
             # 1. check the optimizer state
@@ -615,16 +615,16 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
     sampler_states = dataloader.batch_sampler.state_dict()
     save_states = {"num_consumed_batches": num_consumed_batches}
     if only_state_dict:
-        driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+        driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
     else:
-        driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
+        driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
     # load
     # change batch_size
     dataloader = DataLoader(
         dataset=dataset,
         batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False)
     )
-    load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+    load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
     replaced_loader = load_states.pop("dataloader")
     # 1. check the optimizer state
     # TODO the optimizer state_dict is always empty
@@ -697,9 +697,9 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16):
     sampler_states = dataloader.batch_sampler.sampler.state_dict()
     save_states = {"num_consumed_batches": num_consumed_batches}
     if only_state_dict:
-        driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+        driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
     else:
-        driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
+        driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
     # load
     # change batch_size
@@ -709,7 +709,7 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16):
         dataset,
         batch_sampler=batch_sampler
     )
-    load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+    load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
     replaced_loader = load_states.pop("dataloader")
     # 1. check the optimizer state
@@ -648,7 +648,7 @@ class TestSaveLoad:
            # save the state
             sampler_states = dataloader.batch_sampler.state_dict()
             save_states = {"num_consumed_batches": num_consumed_batches}
-            driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+            driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
             # load
             # change batch_size
             dataloader = dataloader_with_bucketedbatchsampler(
@@ -663,7 +663,7 @@ class TestSaveLoad:
                 rank=driver2.global_rank,
                 pad=True
             )
-            load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+            load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
             replaced_loader = load_states.pop("dataloader")
             # 1. check the optimizer state
             # TODO the optimizer state_dict is always empty
@@ -754,9 +754,9 @@ class TestSaveLoad:
             sampler_states = dataloader.batch_sampler.sampler.state_dict()
             save_states = {"num_consumed_batches": num_consumed_batches}
             if only_state_dict:
-                driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+                driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
             else:
-                driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))])
+                driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))])
             # load
             # change batch_size
             dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False)
@@ -765,7 +765,7 @@ class TestSaveLoad:
                 rank=driver2.global_rank,
                 pad=True
             )
-            load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+            load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
             replaced_loader = load_states.pop("dataloader")
             # 1. check the optimizer state
@@ -37,28 +37,6 @@ def test_get_single_device(driver, device):
     driver = initialize_torch_driver(driver, device, model)
     assert isinstance(driver, TorchSingleDriver)
 
-
-@pytest.mark.torch
-@pytest.mark.parametrize(
-    "device",
-    [0, 1]
-)
-@pytest.mark.parametrize(
-    "driver",
-    ["torch_ddp"]
-)
-@magic_argv_env_context
-def test_get_ddp_2(driver, device):
-    """
-    Test multi-gpu ddp initialization when a single gpu is passed in
-    """
-    model = TorchNormalModel_Classification_1(64, 10)
-    driver = initialize_torch_driver(driver, device, model)
-    assert isinstance(driver, TorchDDPDriver)
-
 @pytest.mark.torch
 @pytest.mark.parametrize(
     "device",
@@ -66,7 +44,7 @@ def test_get_ddp_2(driver, device):
 )
 @pytest.mark.parametrize(
     "driver",
-    ["torch", "torch_ddp"]
+    ["torch"]
 )
 @magic_argv_env_context
 def test_get_ddp(driver, device):
@@ -79,21 +57,6 @@ def test_get_ddp(driver, device):
     assert isinstance(driver, TorchDDPDriver)
 
-
-@pytest.mark.torch
-@pytest.mark.parametrize(
-    ("driver", "device"),
-    [("torch_ddp", "cpu")]
-)
-def test_get_ddp_cpu(driver, device):
-    """
-    Test attempting to initialize distributed training on cpu
-    """
-    model = TorchNormalModel_Classification_1(64, 10)
-    with pytest.raises(ValueError):
-        driver = initialize_torch_driver(driver, device, model)
-
 @pytest.mark.torch
 @pytest.mark.parametrize(
     "device",
@@ -101,7 +64,7 @@ def test_get_ddp_cpu(driver, device):
 )
 @pytest.mark.parametrize(
     "driver",
-    ["torch", "torch_ddp"]
+    ["torch"]
 )
 def test_device_out_of_range(driver, device):
     """
@@ -595,12 +595,12 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
     sampler_states = dataloader.batch_sampler.state_dict()
     save_states = {"num_consumed_batches": num_consumed_batches}
-    driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+    driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
     # load
     # change batch_size
     dataloader = dataloader_with_randombatchsampler(dataset, 2, True, False)
-    load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+    load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
     replaced_loader = load_states.pop("dataloader")
     # 1. check the optimizer state
     # TODO the optimizer state_dict is always empty
@@ -664,12 +664,12 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16):
     sampler_states = dataloader.batch_sampler.sampler.state_dict()
     save_states = {"num_consumed_batches": num_consumed_batches}
-    driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+    driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
     # load
     # change batch_size
     dataloader = dataloader_with_randomsampler(dataset, 2, True, False)
-    load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+    load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
     replaced_loader = load_states.pop("dataloader")
     # 1. check the optimizer state
@@ -7,8 +7,9 @@ from fastNLP import Vocabulary, DataSet, Instance
 from fastNLP.embeddings.torch.char_embedding import LSTMCharEmbedding, CNNCharEmbedding
 
 
+@pytest.mark.torch
 class TestCharEmbed:
-    @pytest.mark.test
+    # @pytest.mark.test
     def test_case_1(self):
         ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['Jack'])])
         vocab = Vocabulary().from_dataset(ds, field_name='words')
@@ -18,7 +19,7 @@ class TestCharEmbed:
         y = embed(x)
         assert tuple(y.size()) == (2, 3, 3)
 
-    @pytest.mark.test
+    # @pytest.mark.test
     def test_case_2(self):
         ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['Jack'])])
         vocab = Vocabulary().from_dataset(ds, field_name='words')