diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index a407a7b7..4362dcce 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -191,6 +191,8 @@ class PaddleDriver(Driver): :return: """ model = self.unwrap_model() + if isinstance(filepath, Path): + filepath = str(filepath) if only_state_dict: states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} paddle.save(states, filepath) @@ -211,6 +213,8 @@ class PaddleDriver(Driver): :return: """ model = self.unwrap_model() + if isinstance(filepath, Path): + filepath = str(filepath) # paddle 中,通过 paddle.jit.save 函数保存的模型也可以通过 paddle.load 加载为相应的 state dict # 但是此时对输入的 path 有要求,必须是 dir/filename 的形式,否则会报错。 dirname, filename = os.path.split(filepath) @@ -274,11 +278,11 @@ class PaddleDriver(Driver): logger.debug("Save optimizer state dict.") states["optimizers_state_dict"] = optimizers_state_dict - paddle.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) + paddle.save(states, str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: - states = paddle.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) + states = paddle.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) # 1. 
加载 optimizers 的状态; optimizers_state_dict = states["optimizers_state_dict"] @@ -435,6 +439,16 @@ class PaddleDriver(Driver): res.shuffle = True else: res.shuffle = False + # RandomBatchSampler 的情况 + elif hasattr(dataloader.batch_sampler, "batch_sampler"): + batch_sampler = dataloader.batch_sampler.batch_sampler + res.sampler = batch_sampler.sampler + if hasattr(batch_sampler.sampler, "shuffle"): + res.shuffle = batch_sampler.sampler.shuffle + elif isinstance(batch_sampler.sampler, RandomSampler): + res.shuffle = True + else: + res.shuffle = False else: res.sampler = None res.shuffle = False diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py index 36982b4c..895ec703 100644 --- a/fastNLP/core/drivers/paddle_driver/utils.py +++ b/fastNLP/core/drivers/paddle_driver/utils.py @@ -4,12 +4,14 @@ import struct import random import inspect import numpy as np +from copy import deepcopy from contextlib import ExitStack, closing from enum import IntEnum from typing import Dict, Optional, Union from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to +from fastNLP.core.samplers import RandomSampler from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES from fastNLP.core.log import logger @@ -18,7 +20,7 @@ if _NEED_IMPORT_PADDLE: import paddle from paddle import nn from paddle.nn import Layer - from paddle.io import DataLoader, BatchSampler + from paddle.io import DataLoader, BatchSampler, Dataset from paddle.amp import auto_cast, GradScaler else: from fastNLP.core.utils.dummy_class import DummyClass as Layer @@ -206,7 +208,6 @@ class DummyGradScaler: def state_dict(self): return {} - def _build_fp16_env(dummy=False): if dummy: auto_cast = ExitStack @@ -260,7 +261,7 @@ def get_device_from_visible(device: Union[str, int], output_type=int): """ 在有 CUDA_VISIBLE_DEVICES 的情况下,获取对应的设备。 如 
CUDA_VISIBLE_DEVICES=2,3 ,device=3 ,则返回1。 - :param devices: 未转化的设备名 + :param device: 未转化的设备名 :param output_type: 返回值的类型 :return: 转化后的设备id """ @@ -365,13 +366,8 @@ def replace_sampler(dataloader, new_sampler): """ 使用 `new_sampler` 重新构建一个 BatchSampler,并替换到 `dataloader` 中 """ - new_batch_sampler = BatchSampler( - dataset=dataloader.batch_sampler.dataset, - sampler=new_sampler, - shuffle=isinstance(dataloader.batch_sampler.sampler, paddle.io.RandomSampler), - batch_size=dataloader.batch_sampler.batch_size, - drop_last=dataloader.batch_sampler.drop_last - ) + new_batch_sampler = deepcopy(dataloader.batch_sampler) + new_batch_sampler.sampler = new_sampler return replace_batch_sampler(dataloader, new_batch_sampler) def optimizer_state_to_device(state, device): diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 3d07766a..b9681121 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -1,11 +1,10 @@ import os -from numpy import isin os.environ["FASTNLP_BACKEND"] = "paddle" import pytest +from pathlib import Path from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver -from fastNLP.core.samplers.reproducible_sampler import RandomSampler -from fastNLP.core.samplers import RandomBatchSampler +from fastNLP.core.samplers import RandomBatchSampler, RandomSampler from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset from tests.helpers.datasets.torch_data import TorchNormalDataset @@ -42,27 +41,101 @@ def prepare_test_save_load(): driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) return driver1, driver2, dataloader -@pytest.mark.parametrize("reproducible", [True, False]) -@pytest.mark.parametrize("only_state_dict", [True, False]) -def 
test_save_and_load(prepare_test_save_load, reproducible, only_state_dict): +@pytest.mark.parametrize("only_state_dict", ([True, False])) +def test_save_and_load_with_randombatchsampler(only_state_dict): """ - 测试save和load函数 - TODO optimizer的state_dict为空,暂时不测试 + 测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 """ try: path = "model.ckp" - driver1, driver2, dataloader = prepare_test_save_load - dataloader = driver1.set_dist_repro_dataloader(dataloader, "dist", reproducible) - driver1.save(path, {}, dataloader, only_state_dict, should_save_model=True) - driver2.load(path, dataloader, only_state_dict, should_load_model=True) + driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) + dataset = PaddleRandomMaxDataset(80, 10) + dataloader = DataLoader( + dataset=dataset, + batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) + ) + + # TODO 断点重训完善后在这里迭代几次 + + sampler_states = dataloader.batch_sampler.state_dict() + if only_state_dict: + driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True) + else: + driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) + states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + + # 1. 检查 optimizer 的状态 + # TODO optimizer 的 state_dict 总是为空 + + # 2. 检查 batch_sampler 是否被正确地加载和替换 + replaced_loader = states["dataloader"] + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] + assert replaced_loader.batch_sampler.data_idx == sampler_states["data_idx"] + + # 3. 检查 model 的参数是否被正确加载 + for batch in dataloader: + res1 = driver1.validate_step(batch) + res2 = driver2.validate_step(batch) + + assert paddle.equal_all(res1["pred"], res2["pred"]) + + # 4. 
检查 batch_idx + # TODO + finally: + synchronize_safe_rm(path) + +@pytest.mark.parametrize("only_state_dict", ([True, False])) +def test_save_and_load_with_randomsampler(only_state_dict): + """ + 测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 + """ + + try: + path = "model.ckp" + + driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) + dataset = PaddleRandomMaxDataset(80, 10) + batch_sampler = BatchSampler(dataset=dataset, batch_size=2) + batch_sampler.sampler = RandomSampler(dataset, True) + dataloader = DataLoader( + dataset, + batch_sampler=batch_sampler + ) + + # TODO 断点重训完善后在这里迭代几次 + + sampler_states = dataloader.batch_sampler.sampler.state_dict() + if only_state_dict: + driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True) + else: + driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) + states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + + # 1. 检查 optimizer 的状态 + # TODO optimizer 的 state_dict 总是为空 + + # 2. 检查 sampler 是否被正确地加载和替换 + replaced_loader = states["dataloader"] + + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] + assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] + assert replaced_loader.batch_sampler.sampler.num_consumed_samples == sampler_states["num_consumed_samples"] + assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] + assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] + # 3. 检查 model 的参数是否被正确加载 for batch in dataloader: res1 = driver1.validate_step(batch) res2 = driver2.validate_step(batch) assert paddle.equal_all(res1["pred"], res2["pred"]) + + # 4. 
检查 batch_idx + # TODO finally: synchronize_safe_rm(path) @@ -144,24 +217,138 @@ class TestSingleDeviceFunction: """ self.driver.move_data_to_device(paddle.rand((32, 64))) - # @pytest.mark.parametrize( - # "dist_sampler", [ - # "dist", - # RandomBatchSampler(BatchSampler(PaddleRandomMaxDataset(320, 10)), 32, False), - # RandomSampler(PaddleRandomMaxDataset(320, 10)) - # ] - # ) - # @pytest.mark.parametrize( - # "reproducible", - # [True, False] - # ) - # def test_set_dist_repro_dataloader(self, dist_sampler, reproducible): - # """ - # 测试set_dist_repro_dataloader函数 - # """ - # dataloader = DataLoader(PaddleRandomMaxDataset(320, 10), batch_size=100, shuffle=True) - - # res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible) + +class TestSetDistReproDataloder: + """ + 专门测试 set_dist_repro_dataloader 函数的类 + """ + def setup_method(self): + self.dataset = PaddleNormalDataset(20) + self.dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) + model = PaddleNormalModel_Classification_1(10, 32) + self.driver = PaddleSingleDriver(model, device="cpu") + + def test_set_dist_repro_dataloader_with_reproducible_false(self): + """ + 测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现 + 当dist为字符串时,此时应该返回原来的 dataloader + """ + replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=False) + + assert replaced_loader is self.dataloader + + def test_set_dist_repro_dataloader_with_reproducible_true(self): + """ + 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 + 当dist为字符串时,此时应该返回新的 dataloader,且 batch_sampler 为 RandomBatchSampler + """ + replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=True) + + assert not (replaced_loader is self.dataloader) + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) + assert replaced_loader.batch_sampler.batch_size == 
self.dataloader.batch_sampler.batch_size + assert replaced_loader.drop_last == self.dataloader.drop_last + + # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) + + def test_set_dist_repro_dataloader_with_dist_batch_sampler(self): + """ + 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler + 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler + """ + dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False) + replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False) + + assert not (replaced_loader is self.dataloader) + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert replaced_loader.batch_sampler is dist + + # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) + + def test_set_dist_repro_dataloader_with_dist_sampler(self): + """ + 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现 + 应该返回新的 dataloader,并将 batch_sampler.sampler 替换为 dist 对应的 Sampler + """ + dist = RandomSampler(self.dataset, shuffle=True) + replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False) + + assert not (replaced_loader is self.dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert not (replaced_loader.batch_sampler is self.dataloader.batch_sampler) + assert replaced_loader.batch_sampler.sampler is dist + assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size + + # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) + + def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self): + """ + 测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 + 应该返回新的 dataloader,且其余各项设置和原来相同 + """ + dataloader = DataLoader( + dataset=self.dataset, + batch_sampler=RandomBatchSampler(BatchSampler(self.dataset, 
batch_size=4), 4, False) + ) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) + + assert not (replaced_loader is dataloader) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size + assert replaced_loader.drop_last == dataloader.drop_last + + # self.check_set_dist_repro_dataloader(dataloader, replaced_loader) + + def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self): + """ + 测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 + 应该返回新的 dataloader,且其余各项设置和原来相同 + """ + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) + batch_sampler.sampler = RandomSampler(self.dataset, True) + dataloader = DataLoader( + self.dataset, + batch_sampler=batch_sampler + ) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) + + assert not (replaced_loader is dataloader) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) + assert replaced_loader.batch_sampler.batch_size == 2 + + # self.check_set_dist_repro_dataloader(dataloader, replaced_loader) + + def check_set_dist_repro_dataloader(self, dataloader, replaced_loader): + """ + 测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确 + """ + # 迭代两个 batch + # 这里会发生 BatchSampler 里 yield 了多次但 dataloader 只取出一次的情况。 + already_seen_idx = set() + for idx, batch in enumerate(replaced_loader): + already_seen_idx.update(batch) + if idx >= 1: + break + if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): + sampler_states = replaced_loader.batch_sampler.state_dict() + else: + sampler_states = replaced_loader.batch_sampler.sampler.state_dict() + print(sampler_states["data_idx"]) + + # 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range + left_idxes = set() + if 
isinstance(replaced_loader.batch_sampler, RandomBatchSampler): + replaced_loader.batch_sampler.load_state_dict(sampler_states) + else: + replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states) + for idx, batch in enumerate(replaced_loader): + left_idxes.update(batch) + + assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) + assert len(left_idxes | already_seen_idx) == len(self.dataset) class TestPaddleDriverFunctions: """ @@ -229,7 +416,7 @@ class TestPaddleDriverFunctions: with pytest.raises(ValueError): PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) - def test_check_dataloader_legacy_in_test(self): + def test_check_dataloader_legality_in_test(self): """ 测试is_train参数为False时,_check_dataloader_legality函数的表现 """ @@ -372,11 +559,78 @@ class TestPaddleDriverFunctions: dataloader = DataLoader(PaddleNormalDataset()) self.driver.set_sampler_epoch(dataloader, 0) - def test_get_dataloader_args(self): + @pytest.mark.parametrize("batch_size", [16]) + @pytest.mark.parametrize("shuffle", [True, False]) + @pytest.mark.parametrize("drop_last", [True, False]) + def test_get_dataloader_args(self, batch_size, shuffle, drop_last): """ - 测试get_dataloader_args + 测试正常情况下 get_dataloader_args 的表现 """ - # 先确保不影响运行 - # TODO:正确性 - dataloader = DataLoader(PaddleNormalDataset()) - res = PaddleSingleDriver.get_dataloader_args(dataloader) \ No newline at end of file + dataloader = DataLoader( + PaddleNormalDataset(), + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + ) + res = PaddleSingleDriver.get_dataloader_args(dataloader) + + assert isinstance(res.dataset, PaddleNormalDataset) + assert isinstance(res.batch_sampler, BatchSampler) + if shuffle: + assert isinstance(res.sampler, paddle.io.RandomSampler) + else: + assert isinstance(res.sampler, paddle.io.SequenceSampler) + assert res.shuffle == shuffle + assert res.batch_size == batch_size + assert res.drop_last == drop_last + + @pytest.mark.parametrize("batch_size", 
[16]) + @pytest.mark.parametrize("shuffle", [True, False]) + @pytest.mark.parametrize("drop_last", [True, False]) + def test_get_dataloader_args_with_randombatchsampler(self, batch_size, shuffle, drop_last): + """ + 测试替换了 batch_sampler 后 get_dataloader_args 的表现 + """ + dataset = PaddleNormalDataset() + dataloader = DataLoader( + dataset, + batch_sampler=RandomBatchSampler( + BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle), + batch_size, + drop_last, + ) + ) + res = PaddleSingleDriver.get_dataloader_args(dataloader) + + assert isinstance(res.dataset, PaddleNormalDataset) + assert isinstance(res.batch_sampler, RandomBatchSampler) + if shuffle: + assert isinstance(res.sampler, paddle.io.RandomSampler) + else: + assert isinstance(res.sampler, paddle.io.SequenceSampler) + assert res.shuffle == shuffle + assert res.batch_size == batch_size + assert res.drop_last == drop_last + + @pytest.mark.parametrize("batch_size", [16]) + @pytest.mark.parametrize("shuffle", [True, False]) + @pytest.mark.parametrize("drop_last", [True, False]) + def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last): + """ + 测试替换了 sampler 后 get_dataloader_args 的表现 + """ + dataset = PaddleNormalDataset() + batch_sampler = BatchSampler(dataset, batch_size=batch_size, drop_last=drop_last) + batch_sampler.sampler = RandomSampler(dataset, shuffle) + dataloader = DataLoader( + dataset, + batch_sampler=batch_sampler, + ) + res = PaddleSingleDriver.get_dataloader_args(dataloader) + + assert isinstance(res.dataset, PaddleNormalDataset) + assert isinstance(res.batch_sampler, BatchSampler) + assert isinstance(res.sampler, RandomSampler) + assert res.shuffle == shuffle + assert res.batch_size == batch_size + assert res.drop_last == drop_last \ No newline at end of file diff --git a/tests/core/drivers/paddle_driver/test_utils.py b/tests/core/drivers/paddle_driver/test_utils.py index b072d09d..690d0fb8 100644 --- a/tests/core/drivers/paddle_driver/test_utils.py +++ 
b/tests/core/drivers/paddle_driver/test_utils.py @@ -1,4 +1,56 @@ -import unittest +import os +import pytest +os.environ["FASTNLP_BACKEND"] = "paddle" + +from fastNLP.core.drivers.paddle_driver.utils import ( + get_device_from_visible, + replace_batch_sampler, + replace_sampler, +) +from fastNLP.core.samplers import RandomBatchSampler, RandomSampler import paddle -from paddle.io import Dataset, DataLoader, DistributedBatchSampler \ No newline at end of file +from paddle.io import DataLoader, BatchSampler + +from tests.helpers.datasets.paddle_data import PaddleNormalDataset + +@pytest.mark.parametrize( + ("user_visible_devices, cuda_visible_devices, device, output_type, correct"), + ( + ("0,1,2,3,4,5,6,7", "0", "cpu", str, "cpu"), + ("0,1,2,3,4,5,6,7", "0", "cpu", int, "cpu"), + ("0,1,2,3,4,5,6,7", "3,4,5", "gpu:4", int, 1), + ("0,1,2,3,4,5,6,7", "3,4,5", "gpu:5", str, "gpu:2"), + ("3,4,5,6", "3,5", 0, int, 0), + ("3,6,7,8", "6,7,8", "gpu:2", str, "gpu:1"), + ) +) +def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices, device, output_type, correct): + os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices + os.environ["USER_CUDA_VISIBLE_DEVICES"] = user_visible_devices + res = get_device_from_visible(device, output_type) + assert res == correct + +def test_replace_batch_sampler(): + dataset = PaddleNormalDataset(10) + dataloader = DataLoader(dataset, batch_size=32) + batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) + + replaced_loader = replace_batch_sampler(dataloader, batch_sampler) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.dataset, PaddleNormalDataset) + assert len(replaced_loader.dataset) == len(dataset) + assert replaced_loader.batch_sampler.batch_size == 16 + +def test_replace_sampler(): + dataset = PaddleNormalDataset(10) + dataloader = DataLoader(dataset, batch_size=32) + 
sampler = RandomSampler(dataset) + + replaced_loader = replace_sampler(dataloader, sampler) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) \ No newline at end of file