From 5ea4f75ff873a7c845c649f5d9046aba4bcd81eb Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sun, 10 Apr 2022 06:54:31 +0000 Subject: [PATCH 01/26] =?UTF-8?q?paddle=20=E7=8E=AF=E5=A2=83=E8=AE=BE?= =?UTF-8?q?=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/envs/set_backend.py | 5 ++--- fastNLP/envs/set_env_on_import.py | 3 +-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/fastNLP/envs/set_backend.py b/fastNLP/envs/set_backend.py index 68a28335..18cc970e 100644 --- a/fastNLP/envs/set_backend.py +++ b/fastNLP/envs/set_backend.py @@ -8,7 +8,7 @@ import sys from collections import defaultdict -from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED +from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED from fastNLP.envs.imports import SUPPORT_BACKENDS from fastNLP.envs.utils import _module_available @@ -65,8 +65,7 @@ def _set_backend(): else: # 设置 USER_CUDA_VISIBLE_DEVICES 表明用户视角中所有设备可见 os.environ[USER_CUDA_VISIBLE_DEVICES] = "" - # TODO 这里的 [0] 可能在单个节点多卡的时候有问题 - os.environ['CUDA_VISIBLE_DEVICES'] = selected_gpus[0] + os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(selected_gpus) os.environ['FLAGS_selected_gpus'] = ",".join([str(g) for g in range(len(selected_gpus))]) os.environ['FLAGS_selected_accelerators'] = ",".join([str(g) for g in range(len(selected_gpus))]) elif 'CUDA_VISIBLE_DEVICES' in os.environ: diff --git a/fastNLP/envs/set_env_on_import.py b/fastNLP/envs/set_env_on_import.py index db978bae..1ca49289 100644 --- a/fastNLP/envs/set_env_on_import.py +++ b/fastNLP/envs/set_env_on_import.py @@ -36,8 +36,7 @@ def set_env_on_import_torch(): # TODO paddle may need set this def set_env_on_import_paddle(): - # todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_LAUNCH_PROCESS - if "PADDLE_TRANERS_NUM" in os.environ and "PADDLE_TRAINER_ID" in os.environ \ + if "PADDLE_TRAINERS_NUM" in os.environ and "PADDLE_TRAINER_ID" in os.environ \ and "PADDLE_RANK_IN_NODE" in os.environ: # 检测到了分布式环境的环境变量 os.environ[FASTNLP_GLOBAL_RANK] = os.environ["PADDLE_TRAINER_ID"] From 791580797c6ba7d9a2aa5938776167e7226d7e17 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sun, 10 Apr 2022 06:55:05 +0000 Subject: [PATCH 02/26] =?UTF-8?q?paddle=20=E5=88=86=E5=B8=83=E5=BC=8F?= =?UTF-8?q?=E7=9A=84=E6=B5=8B=E8=AF=95=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/controllers/test_trainer_fleet.py | 93 ++++++++++++++++++ .../controllers/test_trainer_fleet_outside.py | 98 +++++++++++++++++++ 2 files changed, 191 insertions(+) create mode 100644 tests/core/controllers/test_trainer_fleet.py create mode 100644 tests/core/controllers/test_trainer_fleet_outside.py diff --git a/tests/core/controllers/test_trainer_fleet.py b/tests/core/controllers/test_trainer_fleet.py new file mode 100644 index 00000000..a294ad1f --- /dev/null +++ b/tests/core/controllers/test_trainer_fleet.py @@ -0,0 +1,93 @@ +""" +这个文件测试用户以python -m paddle.distributed.launch 启动的情况 +看看有没有用pytest执行的机会 +python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py +""" +import os +os.environ["FASTNLP_BACKEND"] = "paddle" +import sys +sys.path.append("../../../") + +from dataclasses import dataclass + +from fastNLP.core.controllers.trainer import Trainer +from fastNLP.core.metrics.accuracy import Accuracy +from 
fastNLP.core.callbacks.progress_callback import RichCallback +from fastNLP.core.callbacks import Callback + +import paddle +from paddle.optimizer import Adam +from paddle.io import DataLoader + +from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 +from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset +from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback + +@dataclass +class MNISTTrainFleetConfig: + num_labels: int = 10 + feature_dimension: int = 10 + + batch_size: int = 32 + shuffle: bool = True + validate_every = -1 + +def test_trainer_fleet( + driver, + device, + callbacks, + n_epochs, +): + model = PaddleNormalModel_Classification_1( + num_labels=MNISTTrainFleetConfig.num_labels, + feature_dimension=MNISTTrainFleetConfig.feature_dimension + ) + optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) + + train_dataloader = DataLoader( + dataset=PaddleRandomMaxDataset(6400, MNISTTrainFleetConfig.feature_dimension), + batch_size=MNISTTrainFleetConfig.batch_size, + shuffle=True + ) + val_dataloader = DataLoader( + dataset=PaddleRandomMaxDataset(1280, MNISTTrainFleetConfig.feature_dimension), + batch_size=MNISTTrainFleetConfig.batch_size, + shuffle=True + ) + train_dataloader = train_dataloader + validate_dataloaders = val_dataloader + validate_every = MNISTTrainFleetConfig.validate_every + metrics = {"acc": Accuracy()} + trainer = Trainer( + model=model, + driver=driver, + device=device, + optimizers=optimizers, + train_dataloader=train_dataloader, + validate_dataloaders=validate_dataloaders, + validate_every=validate_every, + input_mapping=None, + output_mapping=None, + metrics=metrics, + + n_epochs=n_epochs, + callbacks=callbacks, + output_from_new_proc="logs", + ) + trainer.run() + +if __name__ == "__main__": + driver = "fleet" + device = [0,2,3] + # driver = "paddle" + # device = 2 + callbacks = [ + # RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), + RichCallback(5), + ] + test_trainer_fleet( + driver=driver, + device=device, + callbacks=callbacks, + n_epochs=5, + ) \ No newline at end of file diff --git a/tests/core/controllers/test_trainer_fleet_outside.py b/tests/core/controllers/test_trainer_fleet_outside.py new file mode 100644 index 00000000..d461e211 --- /dev/null +++ b/tests/core/controllers/test_trainer_fleet_outside.py @@ -0,0 +1,98 @@ +""" +这个文件测试用户以python -m paddle.distributed.launch 启动的情况 +并且自己初始化了 fleet +python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py +""" +import os +os.environ["FASTNLP_BACKEND"] = "paddle" +import sys +sys.path.append("../../../") + +from dataclasses import dataclass + +from fastNLP.core.controllers.trainer import Trainer +from fastNLP.core.metrics.accuracy import Accuracy +from fastNLP.core.callbacks.progress_callback import RichCallback +from fastNLP.core.callbacks import Callback + +import paddle +from paddle.optimizer import Adam +from paddle.io import DataLoader +import paddle.distributed.fleet as fleet + +from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_2 +from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset +from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback + +@dataclass +class MNISTTrainFleetConfig: + num_labels: int = 10 + feature_dimension: int = 10 + + batch_size: int = 32 + shuffle: bool = True + validate_every = -1 + +def test_trainer_fleet( + driver, + device, + callbacks, + n_epochs, +): + fleet.init(is_collective=True) + + model = 
PaddleNormalModel_Classification_2( + num_labels=MNISTTrainFleetConfig.num_labels, + feature_dimension=MNISTTrainFleetConfig.feature_dimension, + ) + optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) + + model = fleet.distributed_model(model) + optimizers = fleet.distributed_optimizer(optimizers) + + train_dataloader = DataLoader( + dataset=PaddleRandomMaxDataset(6400, MNISTTrainFleetConfig.feature_dimension), + batch_size=MNISTTrainFleetConfig.batch_size, + shuffle=True + ) + val_dataloader = DataLoader( + dataset=PaddleRandomMaxDataset(1280, MNISTTrainFleetConfig.feature_dimension), + batch_size=MNISTTrainFleetConfig.batch_size, + shuffle=True + ) + train_dataloader = train_dataloader + validate_dataloaders = val_dataloader + validate_every = MNISTTrainFleetConfig.validate_every + metrics = {"acc": Accuracy()} + trainer = Trainer( + model=model, + driver=driver, + device=device, + optimizers=optimizers, + train_dataloader=train_dataloader, + validate_dataloaders=validate_dataloaders, + validate_every=validate_every, + input_mapping=None, + output_mapping=None, + metrics=metrics, + + n_epochs=n_epochs, + callbacks=callbacks, + output_from_new_proc="logs", + data_device=f"gpu:{os.environ['CUDA_VISIBLE_DEVICES']}" + ) + trainer.run() + +if __name__ == "__main__": + driver = "fleet" + device = [0,2,3] + callbacks = [ + # RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), + RichCallback(5), + ] + test_trainer_fleet( + driver=driver, + device=device, + callbacks=callbacks, + n_epochs=30, + ) \ No newline at end of file From e3d565b6390ecd25a1b6610ea21bbffdd9a50481 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sun, 10 Apr 2022 14:18:53 +0000 Subject: [PATCH 03/26] small --- fastNLP/core/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index 66159f24..73267e7f 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -181,7 +181,7 @@ def check_user_specific_params(user_params: Dict, fn: Callable): return user_params -def dataclass_to_dict(data: "dataclass") -> Dict: +def dataclass_to_dict(data: "dataclasses.dataclass") -> Dict: if not is_dataclass(data): raise TypeError(f"Parameter `data` can only be `dataclass` type instead of {type(data)}.") _dict = dict() From 193c04c9e28fde2566a7bc5b6525f32d90b4dde9 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sun, 10 Apr 2022 14:57:55 +0000 Subject: [PATCH 04/26] =?UTF-8?q?initialize=5Fpaddle=5Fdriver=E7=9A=84?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../paddle_driver/initialize_paddle_driver.py | 25 +++-- .../test_initialize_paddle_driver.py | 96 +++++++++++-------- 2 files changed, 70 insertions(+), 51 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py b/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py index db30517f..98655757 100644 --- a/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py @@ -38,23 +38,19 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ if driver not in {"paddle", "fleet"}: raise ValueError("Parameter `driver` can only be one of these values: ['paddle', 'fleet'].") - cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") user_visible_devices = 
os.getenv("USER_CUDA_VISIBLE_DEVICES") - # 优先级 user > cuda - # 判断单机情况 device 的合法性 - # 分布式情况下通过 world_device 判断 - if user_visible_devices != "": - _could_use_device_num = len(user_visible_devices.split(",")) - elif cuda_visible_devices is not None: - _could_use_device_num = len(cuda_visible_devices.split(",")) - else: - _could_use_device_num = paddle.device.cuda.device_count() + if user_visible_devices is None: + raise RuntimeError("This situation cannot happen, please report a bug to us.") + _could_use_device_num = len(user_visible_devices.split(",")) if isinstance(device, int): if device < 0 and device != -1: raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") - # if device >= _could_use_device_num: - # raise ValueError("The gpu device that parameter `device` specifies is not existed.") - device = f"gpu:{device}" + if device >= _could_use_device_num: + raise ValueError("The gpu device that parameter `device` specifies is not existed.") + if device != -1: + device = f"gpu:{device}" + else: + device = list(range(_could_use_device_num)) elif isinstance(device, Sequence) and not isinstance(device, str): device = list(set(device)) for each in device: @@ -62,6 +58,9 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be 'int' type.") elif each < 0: raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be bigger than 0.") + elif each >= _could_use_device_num: + raise ValueError("When parameter `device` is 'Sequence' type, the value in it should not be bigger than" + " the available gpu number.") if len(device) == 1: # 传入了 [1] 这样的,视为单卡。 device = device[0] diff --git a/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py b/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py index 30d5ef3c..54ef22b6 100644 --- a/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py +++ b/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py @@ -1,83 +1,103 @@ +import os import pytest -from fastNLP.envs.set_backend import set_env -from fastNLP.envs.set_env_on_import import set_env_on_import_paddle - -set_env_on_import_paddle() -set_env("paddle") -import paddle +os.environ["FASTNLP_BACKEND"] = "paddle" +from fastNLP.core.drivers import PaddleSingleDriver, PaddleFleetDriver from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver -from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver -from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver -from tests.helpers.models.paddle_model import PaddleNormalModel_Classification +from fastNLP.envs import get_gpu_count +from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 +from tests.helpers.utils import magic_argv_env_context +import paddle def test_incorrect_driver(): + model = PaddleNormalModel_Classification_1(2, 100) with pytest.raises(ValueError): - driver = initialize_paddle_driver("torch") + driver = initialize_paddle_driver("torch", 0, model) @pytest.mark.parametrize( "device", - ["cpu", "gpu:0", [1, 2, 3], 0, "gpu:1"] + ["cpu", "gpu:0", 0, [1]] ) -def test_get_single_device(device): +@pytest.mark.parametrize( + "driver", + ["paddle"] +) +def test_get_single_device(driver, device): """ 测试正常情况下初始化PaddleSingleDriver的情况 """ - model = PaddleNormalModel_Classification(2, 100) - driver = initialize_paddle_driver("paddle", device, model) - + 
model = PaddleNormalModel_Classification_1(2, 100) + driver = initialize_paddle_driver(driver, device, model) assert isinstance(driver, PaddleSingleDriver) @pytest.mark.parametrize( "device", - ["cpu", "gpu:0", [1, 2, 3], 0, "gpu:1"] + [0, 1] ) -def test_get_single_device_with_visiblde_devices(device): +@pytest.mark.parametrize( + "driver", + ["fleet"] +) +@magic_argv_env_context +def test_get_fleet_2(driver, device): """ - 测试 CUDA_VISIBLE_DEVICES 启动时初始化PaddleSingleDriver的情况 + 测试 fleet 多卡的初始化情况 """ - # TODO - model = PaddleNormalModel_Classification(2, 100) - driver = initialize_paddle_driver("paddle", device, model) + model = PaddleNormalModel_Classification_1(64, 10) + driver = initialize_paddle_driver(driver, device, model) - assert isinstance(driver, PaddleSingleDriver) + assert isinstance(driver, PaddleFleetDriver) @pytest.mark.parametrize( "device", - [[1, 2, 3]] + [[0, 2, 3], -1] +) +@pytest.mark.parametrize( + "driver", + ["paddle", "fleet"] ) -def test_get_fleet(device): +@magic_argv_env_context +def test_get_fleet(driver, device): """ 测试 fleet 多卡的初始化情况 """ - model = PaddleNormalModel_Classification(2, 100) - driver = initialize_paddle_driver("paddle", device, model) + model = PaddleNormalModel_Classification_1(64, 10) + driver = initialize_paddle_driver(driver, device, model) assert isinstance(driver, PaddleFleetDriver) @pytest.mark.parametrize( - "device", - [[1,2,3]] + ("driver", "device"), + [("fleet", "cpu")] ) -def test_get_fleet(device): +@magic_argv_env_context +def test_get_fleet_cpu(driver, device): """ - 测试 launch 启动 fleet 多卡的初始化情况 + 测试试图在 cpu 上初始化分布式训练的情况 """ - # TODO - - model = PaddleNormalModel_Classification(2, 100) - driver = initialize_paddle_driver("paddle", device, model) - - assert isinstance(driver, PaddleFleetDriver) + model = PaddleNormalModel_Classification_1(64, 10) + with pytest.raises(ValueError): + driver = initialize_paddle_driver(driver, device, model) -def test_device_out_of_range(device): +@pytest.mark.parametrize( + "device", + [-2, [0, get_gpu_count() + 1, 3], [-2], get_gpu_count() + 1] +) +@pytest.mark.parametrize( + "driver", + ["paddle", "fleet"] +) +@magic_argv_env_context +def test_device_out_of_range(driver, device): """ 测试传入的device超过范围的情况 """ - pass \ No newline at end of file + model = PaddleNormalModel_Classification_1(2, 100) + with pytest.raises(ValueError): + driver = initialize_paddle_driver(driver, device, model) \ No newline at end of file From da849564d60f457e8c0bbe36b92d56e23408c4e2 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sun, 10 Apr 2022 14:59:05 +0000 Subject: [PATCH 05/26] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=88=A9=E7=94=A8?= =?UTF-8?q?=E5=91=BD=E4=BB=A4=E8=8E=B7=E5=8F=96gpu=E6=95=B0=E7=9B=AE?= =?UTF-8?q?=E7=9A=84=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/envs/__init__.py | 5 +++-- fastNLP/envs/set_backend.py | 18 +++++++++++------- fastNLP/envs/utils.py | 13 +++++++++++++ 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/fastNLP/envs/__init__.py b/fastNLP/envs/__init__.py index 524572b3..4ae30677 100644 --- a/fastNLP/envs/__init__.py +++ b/fastNLP/envs/__init__.py @@ -6,7 +6,8 @@ __all__ = [ 'is_cur_env_distributed', 'get_global_rank', 'rank_zero_call', - 'all_rank_call' + 'all_rank_call', + 'get_gpu_count' ] @@ -14,5 +15,5 @@ from .env import * from .set_env_on_import import set_env_on_import from .set_backend import dump_fastnlp_backend from .imports import * -from .utils import _module_available +from 
.utils import _module_available, get_gpu_count from .distributed import * diff --git a/fastNLP/envs/set_backend.py b/fastNLP/envs/set_backend.py index 18cc970e..a9e82c74 100644 --- a/fastNLP/envs/set_backend.py +++ b/fastNLP/envs/set_backend.py @@ -5,13 +5,13 @@ import os import json import sys +import subprocess from collections import defaultdict -from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED +from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED from fastNLP.envs.imports import SUPPORT_BACKENDS -from fastNLP.envs.utils import _module_available - +from fastNLP.envs.utils import _module_available, get_gpu_count def _set_backend(): """ @@ -56,15 +56,17 @@ def _set_backend(): if 'PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ: # 在分布式子进程下,根据 USER_VISIBLE_DEVICES 得到进程真正占有的设备 selected_gpus = os.environ['FLAGS_selected_gpus'].split(',') - if user_visible_devices is not None and user_visible_devices != "": + if user_visible_devices is not None: # 用户通过 CUDA_VISIBLE_DEVICES 启动了分布式训练 # 此时经过 set_backend,用户的设置会保存在 USER_CUDA_VISIBLE_DEVICES 中 # 我们需要从中找到真正使用的设备编号 user_visible_devices = user_visible_devices.split(",") selected_gpus = ",".join([user_visible_devices[int(i)] for i in selected_gpus]) else: - # 设置 USER_CUDA_VISIBLE_DEVICES 表明用户视角中所有设备可见 - os.environ[USER_CUDA_VISIBLE_DEVICES] = "" + # 没有找到 USER_CUDA_VISIBLE_DEVICES,则将之设置为所有的设备 + os.environ[USER_CUDA_VISIBLE_DEVICES] = ",".join(map(str, list( + range(get_gpu_count()) + ))) os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(selected_gpus) os.environ['FLAGS_selected_gpus'] = ",".join([str(g) for g in range(len(selected_gpus))]) os.environ['FLAGS_selected_accelerators'] = ",".join([str(g) for g in range(len(selected_gpus))]) @@ -77,7 +79,9 @@ def _set_backend(): else: # 没有设置的话限制在单卡上,防止多进程时占用别的卡 os.environ['CUDA_VISIBLE_DEVICES'] = '0' - os.environ[USER_CUDA_VISIBLE_DEVICES] = '' + os.environ[USER_CUDA_VISIBLE_DEVICES] = ",".join(map(str, list( + range(get_gpu_count()) + ))) elif backend == 'jittor': assert _module_available(backend), f"You must have {backend} available to use {backend} backend." 
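# --- editor's note -------------------------------------------------------
# A minimal, self-contained sketch of the device remapping that the hunks
# above implement. The helper name `remap_selected_gpus` and its arguments
# are hypothetical illustrations, not fastNLP API: when the launcher hands a
# child process local indices (FLAGS_selected_gpus), those indices refer to
# positions inside the user's CUDA_VISIBLE_DEVICES list and must be mapped
# back to physical device ids before being re-exported.

def remap_selected_gpus(selected, user_visible=None):
    # `selected`: process-local gpu indices, e.g. ["1", "2"]
    # `user_visible`: the user's original CUDA_VISIBLE_DEVICES string, if any
    if user_visible is not None:
        visible = user_visible.split(",")
        return ",".join(visible[int(i)] for i in selected)
    # no user restriction: every physical device is visible, ids pass through
    return ",".join(str(int(i)) for i in selected)

# e.g. launched with CUDA_VISIBLE_DEVICES=0,2,3 and paddle selecting "1,2":
assert remap_selected_gpus(["1", "2"], "0,2,3") == "2,3"
# -------------------------------------------------------------------------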
diff --git a/fastNLP/envs/utils.py b/fastNLP/envs/utils.py index b06ba615..355c2448 100644 --- a/fastNLP/envs/utils.py +++ b/fastNLP/envs/utils.py @@ -3,6 +3,7 @@ from typing import Callable import importlib from pkg_resources import DistributionNotFound from packaging.version import Version +import subprocess import pkg_resources @@ -46,3 +47,15 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version: if use_base_version: pkg_version = Version(pkg_version.base_version) return op(pkg_version, Version(version)) + +def get_gpu_count(): + """ + 利用命令行获取gpu数目的函数 + :return: gpu数目,如果没有显卡设备则为-1 + """ + try: + lines = subprocess.check_output(['nvidia-smi', '--query-gpu=memory.used', '--format=csv']) + # 经分割后还要除去头部和尾部的换行符 + return len(lines.split(b"\n")) - 2 + except: + return -1 \ No newline at end of file From 9678c559c99b68d572c228bad821031b5389bf31 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sun, 10 Apr 2022 14:59:45 +0000 Subject: [PATCH 06/26] =?UTF-8?q?=E8=B7=9F=E8=BF=9B=E6=96=AD=E7=82=B9?= =?UTF-8?q?=E9=87=8D=E8=AE=AD=E7=9A=84=E8=AE=BE=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/paddle_driver/fleet.py | 50 ++-- .../drivers/paddle_driver/paddle_driver.py | 231 ++++++++++++++---- .../drivers/paddle_driver/single_device.py | 35 ++- fastNLP/core/drivers/paddle_driver/utils.py | 69 ++++-- 4 files changed, 272 insertions(+), 113 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index 0fd74795..3635ae14 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -10,6 +10,7 @@ from .utils import ( _MODE_PARAMETER, get_device_from_visible, reset_seed, + replace_sampler ) from fastNLP.envs.imports import _NEED_IMPORT_PADDLE @@ -19,8 +20,13 @@ from fastNLP.core.utils import ( paddle_move_data_to_device, is_in_paddle_dist, ) -from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler -from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES +from fastNLP.core.samplers import ( + ReproducibleIterator, + RandomSampler, + UnrepeatedDistributedSampler, + re_instantiate_sampler, +) +from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED from fastNLP.core.log import logger if _NEED_IMPORT_PADDLE: @@ -314,23 +320,15 @@ class PaddleFleetDriver(PaddleDriver): def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]], reproducible: bool = False, sampler_or_batch_sampler=None): - # 暂时不支持iterableDataset assert dataloader.dataset_kind != _DatasetKind.ITER, \ "FastNLP does not support `IteratorDataset` now." 
if isinstance(dist, ReproducibleIterator): - dataloader.batch_sampler.sampler = dist - return dataloader - - # paddle 的 BatchSampler 和 DataLoader 没有 shuffle 成员,只能根据 sampler 判断 - # 但是其子类 DistributedBatchSampler 却有 shuffle 成员 - # 因此用 type() 进行严格的判断 - if type(dataloader.batch_sampler) == BatchSampler: - shuffle = isinstance(dataloader.batch_sampler.sampler, RandomSampler) - else: - shuffle = dataloader.batch_sampler.shuffle + dist = re_instantiate_sampler(dist) + return replace_sampler(dataloader, dist) # trainer, evaluator + # 自己初始化了分布式,什么都不做 if dist is None: if reproducible: raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our " @@ -339,40 +337,40 @@ class PaddleFleetDriver(PaddleDriver): return dataloader # trainer elif dist == "dist": + args = self.get_dataloader_args(dataloader) # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; - if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): - dataloader.batch_sampler.sampler.set_distributed( + if isinstance(args.sampler, ReproducibleIterator): + sampler = re_instantiate_sampler(args.sampler) + sampler.set_distributed( num_replicas=self.world_size, rank=self.global_rank, pad=True ) - return dataloader + return replace_sampler(dataloader, sampler) else: sampler = RandomSampler( - dataset=dataloader.dataset, - shuffle=shuffle, - seed=int(os.environ.get("FASTNLP_SEED", 0)) + dataset=args.dataset, + shuffle=args.shuffle, + seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0)) ) sampler.set_distributed( num_replicas=self.world_size, rank=self.global_rank, pad=True ) - dataloader.batch_sampler.sampler = sampler - return dataloader + return replace_sampler(dataloader, sampler) # evaluator elif dist == "unrepeatdist": + args = self.get_dataloader_args(dataloader) sampler = UnrepeatedDistributedSampler( - dataset=dataloader.dataset, - shuffle=shuffle, - seed=int(os.environ.get("FASTNLP_SEED", 0)) + dataset=args.dataset, + shuffle=args.shuffle, ) sampler.set_distributed( num_replicas=self.world_size, rank=self.global_rank ) - dataloader.batch_sampler.sampler = sampler - return dataloader + return replace_sampler(dataloader, sampler) else: raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 84ce6ec2..69f9ed44 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -1,21 +1,31 @@ import os import random -from typing import Union, Optional, Callable, Dict +from typing import Union, Optional, Dict +from pathlib import Path from functools import partial +from dataclasses import dataclass import numpy as np -from .utils import _build_fp16_env +from .utils import _build_fp16_env, optimizer_state_to_device from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.core.drivers.driver import Driver from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device from fastNLP.envs import rank_zero_call -from fastNLP.envs import FASTNLP_SEED_WORKERS +from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.core.log import logger +from fastNLP.core.samplers import ReproducibleBatchSampler if _NEED_IMPORT_PADDLE: import paddle - from paddle.io import DataLoader, IterableDataset + from paddle.io import ( + DataLoader, + IterableDataset, + Dataset, + Sampler, + 
BatchSampler, + RandomSampler, + ) from paddle.optimizer import Optimizer _reduces = { @@ -69,6 +79,8 @@ class PaddleDriver(Driver): # TODO 我们先禁止 dataloader 的 dataset 是 IterableDataset 种类; if isinstance(dataloader.dataset, IterableDataset): raise TypeError("`IterableDataset` is not allowed.") + if dataloader.batch_sampler is None and dataloader.batch_size is None: + raise ValueError(f"At least one of `{dataloader_name}`'s `batch_sampler` and `batch_size` should be set.") else: if not isinstance(dataloader, Dict): raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.") @@ -79,6 +91,9 @@ class PaddleDriver(Driver): f"type, not {type(each_dataloader)}.") if isinstance(each_dataloader.dataset, IterableDataset): raise TypeError("`IterableDataset` is not allowed.") + if dataloader.batch_sampler is None and dataloader.batch_size is None: + raise ValueError(f"For each dataloader of parameter `{dataloader_name}`, at least one of " + f"`batch_sampler` and `batch_size` should be set.") @staticmethod def _check_optimizer_legality(optimizers): @@ -153,45 +168,53 @@ class PaddleDriver(Driver): getattr(self.model, mode)() @rank_zero_call - def save_model(self, filepath: str, only_state_dict: bool = True, model_save_fn: Optional[Callable]=None, **kwargs): + def save_model(self, filepath: str, only_state_dict: bool = True, **kwargs): r""" 保存模型的函数;注意函数 `save` 是用来进行断点重训的函数; 如果 `model_save_fn` 是一个可调用的函数,那么我们会直接运行该函数; :param filepath: 保存文件的文件位置(需要包括文件名); - :param only_state_dict: 是否只保存模型的 `state_dict`;注意该参数仅当 `model_save_fn` 为 None 时有效; - :param model_save_fn: 用户传入的用来代替该函数本身保存逻辑的函数;如果该参数不为 None,那么我们会调用 model_save_fn(path); + :param only_state_dict: 是否只保存模型的 `state_dict`; + :param kwargs: + :return: """ - if model_save_fn is not None: - model_save_fn(filepath) + model = self.unwrap_model() + + if only_state_dict: + states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} + paddle.save(states, filepath) else: - model = self.unwrap_model() - if only_state_dict: - paddle.save(model.state_dict(), filepath) + # paddle 在保存整个模型时需要传入额外参数 + input_spec = kwargs.get("input_spec", None) + if input_spec is None: + raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.") + if self.model_device is not None: + if not self.is_distributed(): + self.move_model_to_device(model, "cpu") + paddle.jit.save(model, filepath, input_spec) + if not self.is_distributed(): + self.move_model_to_device(model, self.model_device) else: - input_spec = kwargs.get("input_spec", None) - if input_spec is None: - raise Exception("To save the whole Paddle Layer, parameter 'input_spec' is needed.") paddle.jit.save(model, filepath, input_spec) - @staticmethod - @rank_zero_call - def load_model(filepath: str, load_dict: bool = True): + def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs): r""" 加载模型的函数;注意函数 `load` 是用来进行断点重训的函数; :param filepath: 需要被加载的对象的文件位置(需要包括文件名); :param load_dict: 是否加载state_dict,默认为True。当用户在save_model时将only_state_dict设置为False时, 即保存了整个模型时,这个参数必须也为False - :return: 返回加载指定文件后的结果; + :param kwargs: + :return: """ - if load_dict: - return paddle.load(filepath) + model = self.unwrap_model() + if only_state_dict: + model.load_dict(paddle.load(filepath)) else: - return paddle.jit.load(filepath) + model.load_dict(paddle.jit.load(filepath).state_dict()) @rank_zero_call - def save(self, folder, states: Dict): + def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = 
True, **kwargs): r""" 断点重训的保存函数,该函数会负责保存模型和 optimizers 的 state_dict; 需要注意 driver 应当是无状态的,即不管什么时候调用 driver 的接口函数,其返回的结果应该都是一样的;因此,断点重训不需要保存 driver @@ -203,48 +226,110 @@ class PaddleDriver(Driver): :param states: 由 trainer 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态,Driver 应该只需要保存 该对象即可, Driver 应该不需要理解该对象,同时在 driver.load() 的时候,需要将 states 返回回去,load()返回的值与这里的 传入的值保持一致。 + :param dataloader: 正在使用的 dataloader,需要保存里面的状态使得之后可以从当前迭代的位置恢复。 + :param only_state_dict: 是否只保存模型的参数,当 should_save_model 为 False ,该参数无效。 + :param should_save_model: 是否应该保存模型,如果为False,Driver 将不负责 model 的保存。 + :return: """ - # 1. 保存模型的状态; - model = self.unwrap_model() - model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} - # 对于单卡的 driver 来讲,我们实际上(现在)不应该考虑用户在DDP环境下使用单卡模式,从而造成效率损失; - states["model_state_dict"] = model_state_dict + # 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变 + # trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; + + # 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; + # paddle 的 DataLoader 在初始化之后 batch_sampler 可能为 None,也可能为用户设置的 batch_sampler + dataloader_args = self.get_dataloader_args(dataloader) + if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): + sampler = dataloader_args.batch_sampler + elif dataloader_args.sampler: + sampler = dataloader_args.sampler + else: + raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") + + if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): + states['sampler_states'] = sampler.state_dict() + else: + raise RuntimeError( + 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') - # 2. 保存 optimizers 的状态; + # 2. 保存模型的状态; + if should_save_model: + model = self.unwrap_model() + if only_state_dict: + model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} + paddle.save(model_state_dict, folder.joinpath(FASTNLP_MODEL_FILENAME)) + logger.debug("Save model state dict") + else: + input_spec = kwargs.get("input_spec", None) + if input_spec is None: + raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.") + paddle.jit.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME), input_spec) + logger.debug("Save model") + + # 3. 
保存 optimizers 的状态; optimizers_state_dict = {} for i in range(len(self.optimizers)): optimizer: Optimizer = self.optimizers[i] optimizer_state = optimizer.state_dict() - optimizer_state = {name: param.cpu().detach().clone() for name, param in optimizer_state.items()} + optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu") optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; - states["optimizers_state_dict"] = optimizers_state_dict - - paddle.save(states, folder) - def load(self, filepath) -> Dict: - r""" - 断点重训的加载函数,注意该函数会负责读取数据,并且恢复模型和 optimizers 的 state_dict 等; - driver 实例需要在该函数中先加载模型和 optimizers 的 state_dict,然后将一个 state 字典返回给 trainer 。 - 因此 save 函数和 load 函数的接受和返回值应该是对应的; - - 该函数需要在所有 rank 上执行。 + logger.debug("Save optimizer state dict") + states["optimizers_state_dict"] = optimizers_state_dict + paddle.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) - :param filepath: 保存断点重训的状态的文件名; - :return: 需要返回 save 函数输入的 states 内容; - """ - states = paddle.load(filepath) + def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: + + states = paddle.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) # 1. 加载 optimizers 的状态; optimizers_state_dict = states["optimizers_state_dict"] for i in range(len(self.optimizers)): - optimizer: paddle.optimizer.Optimizer = self.optimizers[i] + optimizer: Optimizer = self.optimizers[i] optimizer.set_state_dict(optimizers_state_dict[f"optimizer{i}"]) + logger.debug("Load optimizer state dict.") # 2. 加载模型状态; - model = self.unwrap_model() - model.load_dict(states["model_state_dict"]) + if should_load_model: + model = self.unwrap_model() + if only_state_dict: + res = paddle.load(folder.joinpath(FASTNLP_MODEL_FILENAME)) + model.load_dict(res) + logger.debug("Load model state dict.") + else: + model.load_dict(paddle.jit.load(folder.joinpath(FASTNLP_MODEL_FILENAME)).state_dict()) + logger.debug("Load model.") + + # 3. 恢复 sampler 的状态; + dataloader_args = self.get_dataloader_args(dataloader) + sampler = dataloader_args.sampler + if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)): + # 说明这里需要使用 ReproduceSampler 来弄一下了 + if self.is_distributed(): + raise RuntimeError( + "It is not allowed to use single device checkpoint retraining before but ddp now.") + sampler = ReproducibleBatchSampler( + batch_sampler=sampler, + batch_size=dataloader_args.batch_sampler.batch_size, + drop_last=dataloader_args.drop_last + ) + sampler.load_state_dict(states['sampler_states']) + + states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) + + # 4. 
修改 trainer_state.batch_idx_in_epoch + # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; + if not isinstance(sampler, ReproducibleBatchSampler): + if dataloader_args.drop_last: + batch_idx_in_epoch = len( + sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size + else: + batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ + (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size + # sampler 是 batch_sampler; + else: + batch_idx_in_epoch = sampler.batch_idx_in_epoch + + states["batch_idx_in_epoch"] = batch_idx_in_epoch - self.barrier() return states def get_evaluate_context(self): @@ -313,3 +398,53 @@ class PaddleDriver(Driver): """ if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): dataloader.batch_sampler.set_epoch(cur_epoch_idx) + + @staticmethod + def get_dataloader_args(dataloader: "DataLoader"): + """ + 获取 dataloader 的 shuffle 和 drop_last 属性; + """ + + @dataclass + class Res: + dataset: Optional[Dataset] = None + batch_sampler: Optional[BatchSampler] = None + sampler: Optional[Sampler] = None + batch_size: Optional[int] = None + shuffle: Optional[bool] = None + drop_last: Optional[bool] = None + + res = Res() + + # paddle 的 DataLoader 一定会有 dataset 属性; + res.dataset = dataloader.dataset + + if dataloader.batch_sampler is not None: + res.batch_sampler = dataloader.batch_sampler + if hasattr(dataloader.batch_sampler, "batch_size"): + res.batch_size = getattr(dataloader.batch_sampler, "batch_size") + # 用户使用的是自己的 batch_sampler 并且其没有 "batch_size" 属性; + else: + dataloader_iter = iter(dataloader) + pre_sample = next(dataloader_iter) + res.batch_size = pre_sample.shape[0] + + if hasattr(dataloader.batch_sampler, "sampler"): + res.sampler = dataloader.batch_sampler.sampler + if hasattr(dataloader.batch_sampler.sampler, "shuffle"): + res.shuffle = dataloader.batch_sampler.sampler.shuffle + elif isinstance(dataloader.batch_sampler.sampler, RandomSampler): + res.shuffle = True + else: + res.shuffle = False + else: + res.sampler = None + res.shuffle = False + + if hasattr(dataloader.batch_sampler, "drop_last"): + res.drop_last = getattr(dataloader.batch_sampler, "drop_last") + # 用户使用的是自己的 batch_sampler 并且其没有 "drop_last" 属性; + else: + res.drop_last = False + + return res diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py index 97f14bb6..75d80478 100644 --- a/fastNLP/core/drivers/paddle_driver/single_device.py +++ b/fastNLP/core/drivers/paddle_driver/single_device.py @@ -2,6 +2,7 @@ import os from typing import Optional, Dict, Union from .paddle_driver import PaddleDriver +from .utils import replace_batch_sampler, replace_sampler from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES from fastNLP.core.utils import ( @@ -10,7 +11,7 @@ from fastNLP.core.utils import ( get_paddle_device_id, paddle_move_data_to_device, ) -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator, re_instantiate_sampler from fastNLP.core.log import logger if _NEED_IMPORT_PADDLE: @@ -93,11 +94,8 @@ class PaddleSingleDriver(PaddleDriver): self._test_signature_fn = model.forward def setup(self): - user_visible_devices = os.environ[USER_CUDA_VISIBLE_DEVICES] device_id = get_paddle_device_id(self.model_device) - if user_visible_devices is not None and 
user_visible_devices != "": - # 不为空,说明用户设置了 CUDA_VISIBLDE_DEVICES - device_id = user_visible_devices.split(",")[device_id] + device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id] os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) paddle.device.set_device("gpu:0") self.model.to("gpu:0") @@ -145,26 +143,25 @@ class PaddleSingleDriver(PaddleDriver): assert dataloader.dataset_kind != _DatasetKind.ITER, \ "FastNLP does not support `IteratorDataset` now." if isinstance(dist, ReproducibleBatchSampler): - dataloader.batch_sampler = dist - return dataloader - if isinstance(dist, ReproducibleIterator): - dataloader.batch_sampler.sampler = dist - return dataloader + return replace_batch_sampler(dataloader, dist) + elif isinstance(dist, ReproducibleIterator): + return replace_sampler(dataloader, dist) if reproducible: - if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): - return dataloader + args = self.get_dataloader_args(dataloader) + if isinstance(args.sampler, ReproducibleIterator): + sampler = re_instantiate_sampler(args.sampler) + return replace_sampler(dataloader, sampler) elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): - return dataloader + batch_sampler = re_instantiate_sampler(dataloader.batch_sampler) + return replace_batch_sampler(dataloader, batch_sampler) else: - # TODO batch_sampler = ReproducibleBatchSampler( - batch_sampler=dataloader.batch_sampler, - batch_size=dataloader.batch_sampler.batch_size, - drop_last=dataloader.drop_last + batch_sampler=args.batch_sampler, + batch_size=args.batch_size, + drop_last=args.drop_last ) - dataloader.batch_sampler = batch_sampler - return dataloader + return replace_batch_sampler(dataloader, batch_sampler) else: return dataloader diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py index ebe0f6c5..a8121879 100644 --- a/fastNLP/core/drivers/paddle_driver/utils.py +++ b/fastNLP/core/drivers/paddle_driver/utils.py @@ -9,7 +9,7 @@ from enum import IntEnum from typing import Dict, Optional, Union from fastNLP.envs.imports import _NEED_IMPORT_PADDLE -from fastNLP.core.utils import get_paddle_device_id, auto_param_call +from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES from fastNLP.core.log import logger @@ -272,11 +272,9 @@ def get_device_from_visible(device: Union[str, int]): else: # 利用 USER_CUDA_VISIBLDE_DEVICES 获取用户期望的设备 user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) - if user_visible_devices is not None and user_visible_devices != "": - # 不为空,说明用户设置了 CUDA_VISIBLDE_DEVICES - idx = user_visible_devices.split(",")[idx] - else: - idx = str(idx) + if user_visible_devices is None: + raise RuntimeError("This situation cannot happen, please report a bug to us.") + idx = user_visible_devices.split(",")[idx] cuda_visible_devices_list = cuda_visible_devices.split(',') assert idx in cuda_visible_devices_list, "Can't find "\ @@ -285,31 +283,44 @@ def get_device_from_visible(device: Union[str, int]): res = cuda_visible_devices_list.index(idx) return res -def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"): - # 拿到实例属性; +def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler"): + """ + 利用 `batch_sampler` 重新构建一个 DataLoader,起到替换 `batch_sampler` 又不影响原 `dataloader` 的作用。 + 考虑了用户自己定制了 DataLoader 的情形。 + """ + # 拿到非下划线开头的实例属性; instance_attrs = {k: v for k, v in 
vars(dataloader).items() if not k.startswith('_')} - # 拿到 dataloader '__init__' 函数的默认函数签名; + # 拿到 dataloader '__init__' 函数的默认函数签名;可以获取参数名和参数的默认值以及类型 init_params = dict(inspect.signature(dataloader.__init__).parameters) # 这里为什么要单独弄的原因在于,用户在定制自己的 dataloader 的同时可能为了方便只设定一些参数,而后面直接使用 **kwargs 的方式,这时如果 # 其在初始化自己的 dataloader 实例的时候加入了一些其它的新的参数(首先这一步是必要的,因为我们只能通过这样加 sampler;另一方面,用户 # 可能确实通过 **kwargs 加入了一些新的参数),如果假设用户是这样使用的: "super().__init__(**kwargs)",那么我们就只能去 DataLoader - # 中寻找; + # 中寻找;VAR_KEYWORD 代表 **kwargs has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) if has_variadic_kwargs: init_params.update(dict(inspect.signature(DataLoader.__init__).parameters)) del init_params["self"] # 因为我们刚才可能用 DataLoader 的默认参数将用户定制的 dataloader 的参数覆盖掉了,因此需要重新弄一遍; + # 将同时在实例名和参数名中出现且不是默认值的参数收集起来 non_default_params = {name for name, p in init_params.items() if name in instance_attrs and p.default != instance_attrs[name]} # add `dataset` as it might have been replaced with `*args` non_default_params.add("dataset") + # 收集不是默认值的参数和它的值 reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} - reconstruct_args.update({"batch_sampler": sampler, "shuffle": False, "drop_last": False, "batch_size": 1}) - + # persistent_workers 在类中的对应成员带有下划线,因此添加进来 + reconstruct_args.update({ + "batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1, + "persistent_workers": dataloader._persistent_workers, + }) + + # POSITIONAL_OR_KEYWORD 代表一般的参数 + # 收集初始化函数中出现的、一般形式的、不带默认值且不在 reconstruct_args 中的参数 + # 也即它们没有在初始化函数和实例成员中同时出现 required_args = { p.name for p in init_params.values() @@ -323,12 +334,9 @@ def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"): required_args = sorted(required_args) dataloader_self_name = dataloader.__class__.__name__ raise Exception( - f"Trying to inject `DistributedBatchSampler` into the `{dataloader_self_name}` instance. " + f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. " "This would fail as some of the `__init__` arguments are not available as instance attributes. " f"The missing attributes are {required_args}. " - f"HINT: If you wrote the `{dataloader_self_name}` class, define `self.missing_arg_name` or " - "manually add the `DistributedBatchSampler` as: " - f"`{dataloader_self_name}(dataset, sampler=DistributedBatchSampler(dataset))`." ) # 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs; @@ -340,12 +348,33 @@ def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"): missing_kwargs = sorted(missing_kwargs) dataloader_self_name = dataloader.__class__.__name__ raise Exception( - f"Trying to inject `DistributedBatchSampler` into the `{dataloader_self_name}` instance. " + f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. " "This would fail as it doesn't expose all its attributes in the `__init__` signature. " f"The missing arguments are {missing_kwargs}. " - f"HINT: If you wrote the `{dataloader_self_name}` class, add the `__init__` arguments or " - "manually add the `DistributedBatchSampler` as: " - f"`{dataloader_self_name}(dataset, sampler=DistributedBatchSampler(dataset))`." 
) return type(dataloader)(**reconstruct_args) + +def replace_sampler(dataloader, new_sampler): + """ + 使用 `new_sampler` 重新构建一个 BatchSampler,并替换到 `dataloader` 中 + """ + new_batch_sampler = BatchSampler( + dataset=dataloader.batch_sampler.dataset, + sampler=new_sampler, + shuffle=isinstance(dataloader.batch_sampler.sampler, paddle.io.RandomSampler), + batch_size=dataloader.batch_sampler.batch_size, + drop_last=dataloader.batch_sampler.drop_last + ) + return replace_batch_sampler(dataloader, new_batch_sampler) + +def optimizer_state_to_device(state, device): + new_state = {} + for name, param in state.items(): + if isinstance(param, dict): + new_state[name] = optimizer_state_to_device(param, device) + elif isinstance(param, paddle.Tensor): + new_state[name] = paddle_to(param, device).clone() + else: + new_state[name] = param + return new_state From ebfa118ff2e44a8bfff8a169c5be4500c2c194dc Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sun, 10 Apr 2022 15:07:52 +0000 Subject: [PATCH 07/26] =?UTF-8?q?PaddleDriver=E7=9A=84=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E4=BE=8B=E8=B0=83=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../paddle_driver/test_paddle_driver.py | 120 +++++++----------- 1 file changed, 44 insertions(+), 76 deletions(-) diff --git a/tests/core/drivers/paddle_driver/test_paddle_driver.py b/tests/core/drivers/paddle_driver/test_paddle_driver.py index 9febc27d..9308785a 100644 --- a/tests/core/drivers/paddle_driver/test_paddle_driver.py +++ b/tests/core/drivers/paddle_driver/test_paddle_driver.py @@ -1,75 +1,28 @@ -import unittest - -import torch +import os +import pytest +os.environ["FASTNLP_BACKEND"] = "paddle" from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver -import paddle -from paddle.io import Dataset, DataLoader - -class Net(paddle.nn.Layer): - def __init__(self): - super(Net, self).__init__() - - self.fc1 = paddle.nn.Linear(784, 64) - self.fc2 = paddle.nn.Linear(64, 32) - self.fc3 = paddle.nn.Linear(32, 10) - self.fc4 = paddle.nn.Linear(10, 10) - - def forward(self, x): - - x = self.fc1(x) - x = self.fc2(x) - x = self.fc3(x) - x = self.fc4(x) - - return x - - -class PaddleDataset(Dataset): - def __init__(self): - super(PaddleDataset, self).__init__() - self.items = [paddle.rand((3, 4)) for i in range(320)] - - def __len__(self): - return len(self.items) - - def __getitem__(self, idx): - return self.items[idx] - - -class TorchNet(torch.nn.Module): - def __init__(self): - super(TorchNet, self).__init__() - - self.torch_fc1 = torch.nn.Linear(10, 10) - self.torch_softmax = torch.nn.Softmax(0) - self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3) - self.torch_tensor = torch.ones(3, 3) - self.torch_param = torch.nn.Parameter(torch.ones(4, 4)) - - -class TorchDataset(torch.utils.data.Dataset): - def __init__(self): - super(TorchDataset, self).__init__() - self.items = [torch.ones(3, 4) for i in range(320)] - - def __len__(self): - return len(self.items) - - def __getitem__(self, idx): - return self.items[idx] +from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 +from tests.helpers.datasets.paddle_data import PaddleNormalDataset +from tests.helpers.datasets.torch_data import TorchNormalDataset +from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 +import torch +import paddle +from paddle.io import DataLoader -class PaddleDriverTestCase(unittest.TestCase): +class TestPaddleDriverFunctions: """ - 
PaddleDriver的测试类,由于类的特殊性仅测试部分函数,其它的由PaddleSingleDriver和PaddleFleetDriver完成测试 + PaddleDriver的测试类,使用仅测试部分函数,其它的由PaddleSingleDriver和PaddleFleetDriver完成测试 """ - def setUp(self): - model = Net() + @classmethod + def setup_class(self): + model = PaddleNormalModel_Classification_1(10, 32) self.driver = PaddleDriver(model) - def test_check_single_optimizer_legacy(self): + def test_check_single_optimizer_legality(self): """ 测试传入单个optimizer时的表现 """ @@ -80,12 +33,12 @@ class PaddleDriverTestCase(unittest.TestCase): self.driver.set_optimizers(optimizer) - optimizer = torch.optim.Adam(TorchNet().parameters(), 0.01) + optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) # 传入torch的optimizer时,应该报错ValueError with self.assertRaises(ValueError) as cm: self.driver.set_optimizers(optimizer) - def test_check_optimizers_legacy(self): + def test_check_optimizers_legality(self): """ 测试传入optimizer list的表现 """ @@ -99,22 +52,27 @@ class PaddleDriverTestCase(unittest.TestCase): self.driver.set_optimizers(optimizers) optimizers += [ - torch.optim.Adam(TorchNet().parameters(), 0.01) + torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) ] with self.assertRaises(ValueError) as cm: self.driver.set_optimizers(optimizers) - def test_check_dataloader_legacy_in_train(self): + def test_check_dataloader_legality_in_train(self): """ 测试is_train参数为True时,_check_dataloader_legality函数的表现 """ - dataloader = paddle.io.DataLoader(PaddleDataset()) + dataloader = paddle.io.DataLoader(PaddleNormalDataset()) PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True) + # batch_size 和 batch_sampler 均为 None 的情形 + dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) + with self.assertRaises(ValueError) as cm: + PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True) + # 创建torch的dataloader dataloader = torch.utils.data.DataLoader( - TorchDataset(), + TorchNormalDataset(), batch_size=32, shuffle=True ) with self.assertRaises(ValueError) as cm: @@ -125,21 +83,31 @@ class PaddleDriverTestCase(unittest.TestCase): 测试is_train参数为False时,_check_dataloader_legality函数的表现 """ # 此时传入的应该是dict - dataloader = {"train": paddle.io.DataLoader(PaddleDataset()), "test":paddle.io.DataLoader(PaddleDataset())} + dataloader = { + "train": paddle.io.DataLoader(PaddleNormalDataset()), + "test":paddle.io.DataLoader(PaddleNormalDataset()) + } + PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) + + # batch_size 和 batch_sampler 均为 None 的情形 + dataloader = { + "train": paddle.io.DataLoader(PaddleNormalDataset()), + "test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) + } PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) # 传入的不是dict,应该报错 - dataloader = paddle.io.DataLoader(PaddleDataset()) + dataloader = paddle.io.DataLoader(PaddleNormalDataset()) with self.assertRaises(ValueError) as cm: PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) # 创建torch的dataloader train_loader = torch.utils.data.DataLoader( - TorchDataset(), + TorchNormalDataset(), batch_size=32, shuffle=True ) test_loader = torch.utils.data.DataLoader( - TorchDataset(), + TorchNormalDataset(), batch_size=32, shuffle=True ) dataloader = {"train": train_loader, "test": test_loader} @@ -240,7 +208,7 @@ class PaddleDriverTestCase(unittest.TestCase): """ # 先确保不影响运行 # TODO:正确性 - dataloader = DataLoader(PaddleDataset()) + dataloader = DataLoader(PaddleNormalDataset()) self.driver.set_deterministic_dataloader(dataloader) def 
test_set_sampler_epoch(self): @@ -249,7 +217,7 @@ class PaddleDriverTestCase(unittest.TestCase): """ # 先确保不影响运行 # TODO:正确性 - dataloader = DataLoader(PaddleDataset()) + dataloader = DataLoader(PaddleNormalDataset()) self.driver.set_sampler_epoch(dataloader, 0) def test_get_dataloader_args(self): @@ -258,5 +226,5 @@ class PaddleDriverTestCase(unittest.TestCase): """ # 先确保不影响运行 # TODO:正确性 - dataloader = DataLoader(PaddleDataset()) + dataloader = DataLoader(PaddleNormalDataset()) res = PaddleDriver.get_dataloader_args(dataloader) \ No newline at end of file From d1a589147afbca0233dfdc14eb114016d3016432 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 11 Apr 2022 14:08:25 +0000 Subject: [PATCH 08/26] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../drivers/paddle_driver/single_device.py | 2 +- .../paddle_driver/test_paddle_driver.py | 230 -------------- .../paddle_driver/test_single_device.py | 289 ++++++++++++++++-- 3 files changed, 258 insertions(+), 263 deletions(-) delete mode 100644 tests/core/drivers/paddle_driver/test_paddle_driver.py diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py index 75d80478..cee1ebfa 100644 --- a/fastNLP/core/drivers/paddle_driver/single_device.py +++ b/fastNLP/core/drivers/paddle_driver/single_device.py @@ -23,7 +23,7 @@ __all__ = [ ] class PaddleSingleDriver(PaddleDriver): - def __init__(self, model, device: Optional[str], fp16: Optional[bool] = False, **kwargs): + def __init__(self, model, device: str, fp16: Optional[bool] = False, **kwargs): super(PaddleSingleDriver, self).__init__(model, fp16=fp16, **kwargs) if device is None: diff --git a/tests/core/drivers/paddle_driver/test_paddle_driver.py b/tests/core/drivers/paddle_driver/test_paddle_driver.py deleted file mode 100644 index 9308785a..00000000 --- a/tests/core/drivers/paddle_driver/test_paddle_driver.py +++ /dev/null @@ -1,230 +0,0 @@ -import os -import pytest -os.environ["FASTNLP_BACKEND"] = "paddle" - -from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver -from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 -from tests.helpers.datasets.paddle_data import PaddleNormalDataset -from tests.helpers.datasets.torch_data import TorchNormalDataset -from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 - -import torch -import paddle -from paddle.io import DataLoader - -class TestPaddleDriverFunctions: - """ - PaddleDriver的测试类,使用仅测试部分函数,其它的由PaddleSingleDriver和PaddleFleetDriver完成测试 - """ - - @classmethod - def setup_class(self): - model = PaddleNormalModel_Classification_1(10, 32) - self.driver = PaddleDriver(model) - - def test_check_single_optimizer_legality(self): - """ - 测试传入单个optimizer时的表现 - """ - optimizer = paddle.optimizer.Adam( - parameters=self.driver.model.parameters(), - learning_rate=0.01 - ) - - self.driver.set_optimizers(optimizer) - - optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) - # 传入torch的optimizer时,应该报错ValueError - with self.assertRaises(ValueError) as cm: - self.driver.set_optimizers(optimizer) - - def test_check_optimizers_legality(self): - """ - 测试传入optimizer list的表现 - """ - optimizers = [ - paddle.optimizer.Adam( - parameters=self.driver.model.parameters(), - learning_rate=0.01 - ) for i in range(10) - ] - - self.driver.set_optimizers(optimizers) - - 
optimizers += [ - torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) - ] - - with self.assertRaises(ValueError) as cm: - self.driver.set_optimizers(optimizers) - - def test_check_dataloader_legality_in_train(self): - """ - 测试is_train参数为True时,_check_dataloader_legality函数的表现 - """ - dataloader = paddle.io.DataLoader(PaddleNormalDataset()) - PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True) - - # batch_size 和 batch_sampler 均为 None 的情形 - dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) - with self.assertRaises(ValueError) as cm: - PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True) - - # 创建torch的dataloader - dataloader = torch.utils.data.DataLoader( - TorchNormalDataset(), - batch_size=32, shuffle=True - ) - with self.assertRaises(ValueError) as cm: - PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True) - - def test_check_dataloader_legacy_in_test(self): - """ - 测试is_train参数为False时,_check_dataloader_legality函数的表现 - """ - # 此时传入的应该是dict - dataloader = { - "train": paddle.io.DataLoader(PaddleNormalDataset()), - "test":paddle.io.DataLoader(PaddleNormalDataset()) - } - PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) - - # batch_size 和 batch_sampler 均为 None 的情形 - dataloader = { - "train": paddle.io.DataLoader(PaddleNormalDataset()), - "test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) - } - PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) - - # 传入的不是dict,应该报错 - dataloader = paddle.io.DataLoader(PaddleNormalDataset()) - with self.assertRaises(ValueError) as cm: - PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) - - # 创建torch的dataloader - train_loader = torch.utils.data.DataLoader( - TorchNormalDataset(), - batch_size=32, shuffle=True - ) - test_loader = torch.utils.data.DataLoader( - TorchNormalDataset(), - batch_size=32, shuffle=True - ) - dataloader = {"train": train_loader, "test": test_loader} - with self.assertRaises(ValueError) as cm: - PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) - - def test_tensor_to_numeric(self): - """ - 测试tensor_to_numeric函数 - """ - # 单个张量 - tensor = paddle.to_tensor(3) - res = PaddleDriver.tensor_to_numeric(tensor) - self.assertEqual(res, 3) - - tensor = paddle.rand((3, 4)) - res = PaddleDriver.tensor_to_numeric(tensor) - self.assertListEqual(res, tensor.tolist()) - - # 张量list - tensor_list = [paddle.rand((6, 4, 2)) for i in range(10)] - res = PaddleDriver.tensor_to_numeric(tensor_list) - self.assertTrue(res, list) - tensor_list = [t.tolist() for t in tensor_list] - self.assertListEqual(res, tensor_list) - - # 张量tuple - tensor_tuple = tuple([paddle.rand((6, 4, 2)) for i in range(10)]) - res = PaddleDriver.tensor_to_numeric(tensor_tuple) - self.assertTrue(res, tuple) - tensor_tuple = tuple([t.tolist() for t in tensor_tuple]) - self.assertTupleEqual(res, tensor_tuple) - - # 张量dict - tensor_dict = { - "tensor": paddle.rand((3, 4)), - "list": [paddle.rand((6, 4, 2)) for i in range(10)], - "dict":{ - "list": [paddle.rand((6, 4, 2)) for i in range(10)], - "tensor": paddle.rand((3, 4)) - }, - "int": 2, - "string": "test string" - } - - res = PaddleDriver.tensor_to_numeric(tensor_dict) - self.assertIsInstance(res, dict) - self.assertListEqual(res["tensor"], tensor_dict["tensor"].tolist()) - self.assertIsInstance(res["list"], list) - for r, d in zip(res["list"], tensor_dict["list"]): - self.assertListEqual(r, d.tolist()) - 
self.assertIsInstance(res["int"], int) - self.assertIsInstance(res["string"], str) - self.assertIsInstance(res["dict"], dict) - self.assertIsInstance(res["dict"]["list"], list) - for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]): - self.assertListEqual(r, d.tolist()) - self.assertListEqual(res["dict"]["tensor"], tensor_dict["dict"]["tensor"].tolist()) - - def test_set_model_mode(self): - """ - 测试set_model_mode函数 - """ - self.driver.set_model_mode("train") - self.assertTrue(self.driver.model.training) - self.driver.set_model_mode("eval") - self.assertFalse(self.driver.model.training) - # 应该报错 - with self.assertRaises(AssertionError) as cm: - self.driver.set_model_mode("test") - - def test_move_model_to_device_cpu(self): - """ - 测试move_model_to_device函数 - """ - PaddleDriver.move_model_to_device(self.driver.model, "cpu") - self.assertTrue(self.driver.model.fc1.weight.place.is_cpu_place()) - - def test_move_model_to_device_gpu(self): - """ - 测试move_model_to_device函数 - """ - PaddleDriver.move_model_to_device(self.driver.model, "gpu:0") - self.assertTrue(self.driver.model.fc1.weight.place.is_gpu_place()) - self.assertEqual(self.driver.model.fc1.weight.place.gpu_device_id(), 0) - - def test_worker_init_function(self): - """ - 测试worker_init_function - """ - # 先确保不影响运行 - # TODO:正确性 - PaddleDriver.worker_init_function(0) - - def test_set_deterministic_dataloader(self): - """ - 测试set_deterministic_dataloader - """ - # 先确保不影响运行 - # TODO:正确性 - dataloader = DataLoader(PaddleNormalDataset()) - self.driver.set_deterministic_dataloader(dataloader) - - def test_set_sampler_epoch(self): - """ - 测试set_sampler_epoch - """ - # 先确保不影响运行 - # TODO:正确性 - dataloader = DataLoader(PaddleNormalDataset()) - self.driver.set_sampler_epoch(dataloader, 0) - - def test_get_dataloader_args(self): - """ - 测试get_dataloader_args - """ - # 先确保不影响运行 - # TODO:正确性 - dataloader = DataLoader(PaddleNormalDataset()) - res = PaddleDriver.get_dataloader_args(dataloader) \ No newline at end of file diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 33662d7f..791b1203 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -1,20 +1,20 @@ +import os +os.environ["FASTNLP_BACKEND"] = "paddle" import pytest -from fastNLP.envs.set_backend import set_env -from fastNLP.envs.set_env_on_import import set_env_on_import_paddle - -set_env_on_import_paddle() -set_env("paddle") -import paddle -from paddle.io import DataLoader, BatchSampler - from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver from fastNLP.core.samplers.reproducible_sampler import RandomSampler from fastNLP.core.samplers import ReproducibleBatchSampler -from tests.helpers.models.paddle_model import PaddleNormalModel_Classification -from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST, PaddleRandomDataset +from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 +from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset +from tests.helpers.datasets.torch_data import TorchNormalDataset +from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from fastNLP.core import synchronize_safe_rm +import paddle +from paddle.io import DataLoader, BatchSampler +import torch + ############################################################################ # @@ -26,32 +26,35 @@ def generate_random_driver(features, labels): """ 生成driver """ 
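    # 该辅助函数构造一个随机初始化的分类模型与 Adam 优化器,并包装为 CPU 上的
    # PaddleSingleDriver;下方的 save/load 测试用它生成两个独立的 driver 实例进行对比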
- model = PaddleNormalModel_Classification(labels, features) + model = PaddleNormalModel_Classification_1(labels, features) opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01) - driver = PaddleSingleDriver(model) + driver = PaddleSingleDriver(model, device="cpu") driver.set_optimizers(opt) return driver @pytest.fixture def prepare_test_save_load(): - dataset = PaddleRandomDataset(num_of_data=320, features=64, labels=8) + dataset = PaddleRandomMaxDataset(320, 10) dataloader = DataLoader(dataset, batch_size=32) - driver1, driver2 = generate_random_driver(64, 8), generate_random_driver(64, 8) + driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) return driver1, driver2, dataloader -def test_save_and_load(prepare_test_save_load): +@pytest.mark.parametrize("reproducible", [True, False]) +@pytest.mark.parametrize("only_state_dict", [True, False]) +def test_save_and_load(prepare_test_save_load, reproducible, only_state_dict): """ 测试save和load函数 TODO optimizer的state_dict为空,暂时不测试 """ try: - path = "model.pdparams" + path = "model.ckp" driver1, driver2, dataloader = prepare_test_save_load + dataloader = driver1.set_dist_repro_dataloader(dataloader, "dist", reproducible) - driver1.save(path, {}) - driver2.load(path) + driver1.save(path, {}, dataloader, only_state_dict, should_save_model=True) + driver2.load(path, dataloader, only_state_dict, should_load_model=True) for batch in dataloader: res1 = driver1.validate_step(batch) @@ -67,11 +70,11 @@ def test_save_and_load_state_dict(prepare_test_save_load): TODO optimizer的state_dict为空,暂时不测试 """ try: - path = "model.pdparams" + path = "dict" driver1, driver2, dataloader = prepare_test_save_load driver1.save_model(path) - driver2.model.load_dict(driver2.load_model(path)) + driver2.load_model(path) for batch in dataloader: res1 = driver1.validate_step(batch) @@ -87,11 +90,11 @@ def test_save_and_load_whole_model(prepare_test_save_load): TODO optimizer的state_dict为空,暂时不测试 """ try: - path = "model.pdparams" + path = "model" driver1, driver2, dataloader = prepare_test_save_load driver1.save_model(path, only_state_dict=False, input_spec=[next(iter(dataloader))["x"]]) - driver2.model = driver2.load_model(path, load_dict=False) + driver2.load_model(path, only_state_dict=False) for batch in dataloader: res1 = driver1.validate_step(batch) @@ -99,7 +102,9 @@ def test_save_and_load_whole_model(prepare_test_save_load): assert paddle.equal_all(res1["pred"], res2["pred"]) finally: - synchronize_safe_rm(path) + synchronize_safe_rm(path + ".pdiparams") + synchronize_safe_rm(path + ".pdiparams.info") + synchronize_safe_rm(path + ".pdmodel") class TestSingleDeviceFunction: @@ -109,8 +114,8 @@ class TestSingleDeviceFunction: @classmethod def setup_class(cls): - model = PaddleNormalModel_Classification(10, 784) - cls.driver = PaddleSingleDriver(model) + model = PaddleNormalModel_Classification_1(10, 784) + cls.driver = PaddleSingleDriver(model, device="gpu") def test_unwrap_model(self): """ @@ -129,7 +134,7 @@ class TestSingleDeviceFunction: """ 测试get_model_device """ - self.driver = PaddleSingleDriver(PaddleNormalModel_Classification(10, 784), "cpu") + self.driver = PaddleSingleDriver(PaddleNormalModel_Classification_1(10, 784), "cpu") device = self.driver.get_model_device() assert device == "cpu", device @@ -137,7 +142,7 @@ class TestSingleDeviceFunction: """ 测试get_model_device """ - self.driver = PaddleSingleDriver(PaddleNormalModel_Classification(10, 784), "gpu:0") + self.driver = 
PaddleSingleDriver(PaddleNormalModel_Classification_1(10, 784), "gpu:0") device = self.driver.get_model_device() assert device == "gpu:0", device @@ -152,8 +157,11 @@ class TestSingleDeviceFunction: self.driver.move_data_to_device(paddle.rand((32, 64))) @pytest.mark.parametrize( - "dist_sampler", - ["dist", ReproducibleBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))] + "dist_sampler", [ + "dist", + ReproducibleBatchSampler(BatchSampler(PaddleRandomMaxDataset(320, 10)), 32, False), + RandomSampler(PaddleRandomMaxDataset(320, 10)) + ] ) @pytest.mark.parametrize( "reproducible", @@ -161,8 +169,225 @@ class TestSingleDeviceFunction: ) def test_repalce_sampler(self, dist_sampler, reproducible): """ - 测试replace_sampler函数 + 测试set_dist_repro_dataloader函数 + """ + dataloader = DataLoader(PaddleRandomMaxDataset(320, 10), batch_size=100, shuffle=True) + + res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible) + +class TestPaddleDriverFunctions: + """ + 使用 PaddleSingleDriver 测试基类的函数 + """ + + @classmethod + def setup_class(self): + model = PaddleNormalModel_Classification_1(10, 32) + self.driver = PaddleSingleDriver(model, device="gpu") + + def test_check_single_optimizer_legality(self): + """ + 测试传入单个optimizer时的表现 + """ + optimizer = paddle.optimizer.Adam( + parameters=self.driver.model.parameters(), + learning_rate=0.01 + ) + + self.driver.set_optimizers(optimizer) + + optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) + # 传入torch的optimizer时,应该报错ValueError + with self.assertRaises(ValueError) as cm: + self.driver.set_optimizers(optimizer) + + def test_check_optimizers_legality(self): + """ + 测试传入optimizer list的表现 + """ + optimizers = [ + paddle.optimizer.Adam( + parameters=self.driver.model.parameters(), + learning_rate=0.01 + ) for i in range(10) + ] + + self.driver.set_optimizers(optimizers) + + optimizers += [ + torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) + ] + + with self.assertRaises(ValueError) as cm: + self.driver.set_optimizers(optimizers) + + def test_check_dataloader_legality_in_train(self): + """ + 测试is_train参数为True时,_check_dataloader_legality函数的表现 + """ + dataloader = paddle.io.DataLoader(PaddleNormalDataset()) + PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) + + # batch_size 和 batch_sampler 均为 None 的情形 + dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) + with self.assertRaises(ValueError) as cm: + PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) + + # 创建torch的dataloader + dataloader = torch.utils.data.DataLoader( + TorchNormalDataset(), + batch_size=32, shuffle=True + ) + with self.assertRaises(ValueError) as cm: + PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) + + def test_check_dataloader_legacy_in_test(self): + """ + 测试is_train参数为False时,_check_dataloader_legality函数的表现 + """ + # 此时传入的应该是dict + dataloader = { + "train": paddle.io.DataLoader(PaddleNormalDataset()), + "test":paddle.io.DataLoader(PaddleNormalDataset()) + } + PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) + + # batch_size 和 batch_sampler 均为 None 的情形 + dataloader = { + "train": paddle.io.DataLoader(PaddleNormalDataset()), + "test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) + } + PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) + + # 传入的不是dict,应该报错 + dataloader = 
paddle.io.DataLoader(PaddleNormalDataset()) + with self.assertRaises(ValueError) as cm: + PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) + + # 创建torch的dataloader + train_loader = torch.utils.data.DataLoader( + TorchNormalDataset(), + batch_size=32, shuffle=True + ) + test_loader = torch.utils.data.DataLoader( + TorchNormalDataset(), + batch_size=32, shuffle=True + ) + dataloader = {"train": train_loader, "test": test_loader} + with self.assertRaises(ValueError) as cm: + PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) + + def test_tensor_to_numeric(self): + """ + 测试tensor_to_numeric函数 """ - dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True) + # 单个张量 + tensor = paddle.to_tensor(3) + res = PaddleSingleDriver.tensor_to_numeric(tensor) + self.assertEqual(res, 3) + + tensor = paddle.rand((3, 4)) + res = PaddleSingleDriver.tensor_to_numeric(tensor) + self.assertListEqual(res, tensor.tolist()) + + # 张量list + tensor_list = [paddle.rand((6, 4, 2)) for i in range(10)] + res = PaddleSingleDriver.tensor_to_numeric(tensor_list) + self.assertTrue(res, list) + tensor_list = [t.tolist() for t in tensor_list] + self.assertListEqual(res, tensor_list) + + # 张量tuple + tensor_tuple = tuple([paddle.rand((6, 4, 2)) for i in range(10)]) + res = PaddleSingleDriver.tensor_to_numeric(tensor_tuple) + self.assertTrue(res, tuple) + tensor_tuple = tuple([t.tolist() for t in tensor_tuple]) + self.assertTupleEqual(res, tensor_tuple) + + # 张量dict + tensor_dict = { + "tensor": paddle.rand((3, 4)), + "list": [paddle.rand((6, 4, 2)) for i in range(10)], + "dict":{ + "list": [paddle.rand((6, 4, 2)) for i in range(10)], + "tensor": paddle.rand((3, 4)) + }, + "int": 2, + "string": "test string" + } + + res = PaddleSingleDriver.tensor_to_numeric(tensor_dict) + self.assertIsInstance(res, dict) + self.assertListEqual(res["tensor"], tensor_dict["tensor"].tolist()) + self.assertIsInstance(res["list"], list) + for r, d in zip(res["list"], tensor_dict["list"]): + self.assertListEqual(r, d.tolist()) + self.assertIsInstance(res["int"], int) + self.assertIsInstance(res["string"], str) + self.assertIsInstance(res["dict"], dict) + self.assertIsInstance(res["dict"]["list"], list) + for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]): + self.assertListEqual(r, d.tolist()) + self.assertListEqual(res["dict"]["tensor"], tensor_dict["dict"]["tensor"].tolist()) + + def test_set_model_mode(self): + """ + 测试set_model_mode函数 + """ + self.driver.set_model_mode("train") + self.assertTrue(self.driver.model.training) + self.driver.set_model_mode("eval") + self.assertFalse(self.driver.model.training) + # 应该报错 + with self.assertRaises(AssertionError) as cm: + self.driver.set_model_mode("test") + + def test_move_model_to_device_cpu(self): + """ + 测试move_model_to_device函数 + """ + PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu") + self.assertTrue(self.driver.model.fc1.weight.place.is_cpu_place()) - res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible) \ No newline at end of file + def test_move_model_to_device_gpu(self): + """ + 测试move_model_to_device函数 + """ + PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu:0") + self.assertTrue(self.driver.model.fc1.weight.place.is_gpu_place()) + self.assertEqual(self.driver.model.fc1.weight.place.gpu_device_id(), 0) + + def test_worker_init_function(self): + """ + 测试worker_init_function + """ + # 先确保不影响运行 + # TODO:正确性 + 
PaddleSingleDriver.worker_init_function(0) + + def test_set_deterministic_dataloader(self): + """ + 测试set_deterministic_dataloader + """ + # 先确保不影响运行 + # TODO:正确性 + dataloader = DataLoader(PaddleNormalDataset()) + self.driver.set_deterministic_dataloader(dataloader) + + def test_set_sampler_epoch(self): + """ + 测试set_sampler_epoch + """ + # 先确保不影响运行 + # TODO:正确性 + dataloader = DataLoader(PaddleNormalDataset()) + self.driver.set_sampler_epoch(dataloader, 0) + + def test_get_dataloader_args(self): + """ + 测试get_dataloader_args + """ + # 先确保不影响运行 + # TODO:正确性 + dataloader = DataLoader(PaddleNormalDataset()) + res = PaddleSingleDriver.get_dataloader_args(dataloader) \ No newline at end of file From 2366bc320bcc6d5bbd2b28e703572a3b8b71480d Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 11 Apr 2022 15:04:39 +0000 Subject: [PATCH 09/26] =?UTF-8?q?=E8=B7=9F=E8=BF=9B=E6=96=AD=E7=82=B9?= =?UTF-8?q?=E9=87=8D=E8=AE=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/paddle_driver/fleet.py | 59 ++++++++++++---- .../drivers/paddle_driver/paddle_driver.py | 55 +++++++-------- .../drivers/paddle_driver/single_device.py | 69 +++++++------------ 3 files changed, 94 insertions(+), 89 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index 2a1d5228..86198959 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -10,7 +10,8 @@ from .utils import ( _MODE_PARAMETER, get_device_from_visible, reset_seed, - replace_sampler + replace_sampler, + replace_batch_sampler, ) from fastNLP.envs.imports import _NEED_IMPORT_PADDLE @@ -23,10 +24,12 @@ from fastNLP.core.utils import ( from fastNLP.core.samplers import ( RandomBatchSampler, ReproducibleSampler, - ReproducibleIterator, + ReproducibleBatchSampler, RandomSampler, - UnrepeatedDistributedSampler, + UnrepeatedSampler, + UnrepeatedSequentialSampler, re_instantiate_sampler, + conversion_between_reproducible_and_unrepeated_sampler, ) from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED from fastNLP.core.log import logger @@ -261,7 +264,6 @@ class PaddleFleetDriver(PaddleDriver): 当用户使用了 `python -m paddle.distributed.launch xxx.py` 启动时,我们需要 根据 paddle 设置的环境变量来获得各种属性 """ - print("set_from_env") self.world_size = dist.get_world_size() self.global_rank = dist.get_rank() @@ -325,23 +327,50 @@ class PaddleFleetDriver(PaddleDriver): # 暂时不支持iterableDataset assert dataloader.dataset_kind != _DatasetKind.ITER, \ "FastNLP does not support `IteratorDataset` now." 
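The rewritten dispatch below distinguishes four cases by the type of `dist`. A minimal sketch of the call sites that reach each branch (the `driver`/`dataloader`/`restored_sampler` names are hypothetical, not from this patch):

    # 断点重训:driver.load 中传入恢复了状态的 ReproducibleBatchSampler / ReproducibleSampler
    dataloader = driver.set_dist_repro_dataloader(dataloader, dist=restored_sampler)
    # Trainer 初始化:传入 "dist",由 driver 为现有 sampler 注入 world_size / rank 信息
    dataloader = driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
    # Evaluator:传入 "unrepeatdist",保证各卡之间样本不重复
    dataloader = driver.set_dist_repro_dataloader(dataloader, dist="unrepeatdist")
    # 用户自己初始化分布式:dist 为 None,此时 reproducible 必须为 False,否则报 RuntimeError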
- if isinstance(dist, ReproducibleIterator): - dist = re_instantiate_sampler(dist) + # 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用; + # 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; + if isinstance(dist, ReproducibleBatchSampler): + dist.set_distributed( + num_replicas=self.world_size, + rank=self.global_rank, + pad=True + ) + return replace_batch_sampler(dataloader, dist) + if isinstance(dist, ReproducibleSampler): + dist.set_distributed( + num_replicas=self.world_size, + rank=self.global_rank, + pad=True + ) return replace_sampler(dataloader, dist) + # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; # trainer, evaluator - # 自己初始化了分布式,什么都不做 if dist is None: if reproducible: - raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our " + raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " "control.") else: + if isinstance(dist, ReproducibleBatchSampler): + dist = re_instantiate_sampler(dist) + return replace_batch_sampler(dataloader, dist) + if isinstance(dist, ReproducibleSampler): + dist = re_instantiate_sampler(dist) + return replace_sampler(dataloader, dist) return dataloader # trainer elif dist == "dist": args = self.get_dataloader_args(dataloader) # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; - if isinstance(args.sampler, ReproducibleIterator): + if isinstance(args.batch_sampler, ReproducibleBatchSampler): + batch_sampler = re_instantiate_sampler(args.batch_sampler) + batch_sampler.set_distributed( + num_replicas=self.world_size, + rank=self.global_rank, + pad=True + ) + return replace_batch_sampler(dataloader, batch_sampler) + elif isinstance(args.sampler, ReproducibleSampler): sampler = re_instantiate_sampler(args.sampler) sampler.set_distributed( num_replicas=self.world_size, @@ -364,10 +393,14 @@ class PaddleFleetDriver(PaddleDriver): # evaluator elif dist == "unrepeatdist": args = self.get_dataloader_args(dataloader) - sampler = UnrepeatedDistributedSampler( - dataset=args.dataset, - shuffle=args.shuffle, - ) + if isinstance(args.sampler, ReproducibleSampler): + sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) + elif not isinstance(args.sampler, UnrepeatedSampler): + sampler = UnrepeatedSequentialSampler( + dataset=args.dataset + ) + else: + sampler = re_instantiate_sampler(args.sampler) sampler.set_distributed( num_replicas=self.world_size, rank=self.global_rank diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 69f9ed44..95e6215e 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -14,7 +14,7 @@ from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device from fastNLP.envs import rank_zero_call from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.core.log import logger -from fastNLP.core.samplers import ReproducibleBatchSampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler if _NEED_IMPORT_PADDLE: import paddle @@ -178,11 +178,13 @@ class PaddleDriver(Driver): :param kwargs: :return: """ + debug = kwargs.get("debug", False) model = self.unwrap_model() - if only_state_dict: states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} paddle.save(states, filepath) + if 
debug: + logger.debug("Save model state dict.") else: # paddle 在保存整个模型时需要传入额外参数 input_spec = kwargs.get("input_spec", None) @@ -196,6 +198,8 @@ class PaddleDriver(Driver): self.move_model_to_device(model, self.model_device) else: paddle.jit.save(model, filepath, input_spec) + if debug: + logger.debug("Save model.") def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs): r""" @@ -207,11 +211,16 @@ class PaddleDriver(Driver): :param kwargs: :return: """ + debug = kwargs.get("debug", False) model = self.unwrap_model() if only_state_dict: model.load_dict(paddle.load(filepath)) + if debug: + logger.debug("Load model state dict.") else: model.load_dict(paddle.jit.load(filepath).state_dict()) + if debug: + logger.debug("Load model.") @rank_zero_call def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): @@ -252,17 +261,7 @@ class PaddleDriver(Driver): # 2. 保存模型的状态; if should_save_model: - model = self.unwrap_model() - if only_state_dict: - model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} - paddle.save(model_state_dict, folder.joinpath(FASTNLP_MODEL_FILENAME)) - logger.debug("Save model state dict") - else: - input_spec = kwargs.get("input_spec", None) - if input_spec is None: - raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.") - paddle.jit.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME), input_spec) - logger.debug("Save model") + self.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, debug=True, **kwargs) # 3. 保存 optimizers 的状态; optimizers_state_dict = {} @@ -272,7 +271,7 @@ class PaddleDriver(Driver): optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu") optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; - logger.debug("Save optimizer state dict") + logger.debug("Save optimizer state dict.") states["optimizers_state_dict"] = optimizers_state_dict paddle.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) @@ -289,30 +288,23 @@ class PaddleDriver(Driver): # 2. 加载模型状态; if should_load_model: - model = self.unwrap_model() - if only_state_dict: - res = paddle.load(folder.joinpath(FASTNLP_MODEL_FILENAME)) - model.load_dict(res) - logger.debug("Load model state dict.") - else: - model.load_dict(paddle.jit.load(folder.joinpath(FASTNLP_MODEL_FILENAME)).state_dict()) - logger.debug("Load model.") + self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, debug=True) # 3. 
恢复 sampler 的状态; dataloader_args = self.get_dataloader_args(dataloader) - sampler = dataloader_args.sampler - if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)): - # 说明这里需要使用 ReproduceSampler 来弄一下了 - if self.is_distributed(): - raise RuntimeError( - "It is not allowed to use single device checkpoint retraining before but ddp now.") + if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): + sampler = dataloader_args.batch_sampler + elif isinstance(dataloader_args.sampler, ReproducibleSampler): + sampler = dataloader_args.sampler + elif self.is_distributed(): + raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") + else: sampler = ReproducibleBatchSampler( - batch_sampler=sampler, - batch_size=dataloader_args.batch_sampler.batch_size, + batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, + batch_size=dataloader_args.batch_size, drop_last=dataloader_args.drop_last ) sampler.load_state_dict(states['sampler_states']) - states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) # 4. 修改 trainer_state.batch_idx_in_epoch @@ -420,6 +412,7 @@ class PaddleDriver(Driver): res.dataset = dataloader.dataset if dataloader.batch_sampler is not None: + # 不过在 paddle 中,我们限定了 batch_sampler 不能为 None res.batch_sampler = dataloader.batch_sampler if hasattr(dataloader.batch_sampler, "batch_size"): res.batch_size = getattr(dataloader.batch_sampler, "batch_size") diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py index 83c3112a..dd5a340a 100644 --- a/fastNLP/core/drivers/paddle_driver/single_device.py +++ b/fastNLP/core/drivers/paddle_driver/single_device.py @@ -11,7 +11,7 @@ from fastNLP.core.utils import ( get_paddle_device_id, paddle_move_data_to_device, ) -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler from fastNLP.core.log import logger if _NEED_IMPORT_PADDLE: @@ -137,55 +137,34 @@ class PaddleSingleDriver(PaddleDriver): """ return paddle_move_data_to_device(batch, "gpu:0") -<<<<<<< HEAD - def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator], -======= - def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], ->>>>>>> 388e426d78e8985a2f34dc83dfffe881274239a1 - reproducible: bool = False, sampler_or_batch_sampler=None): - # 暂时不支持IteratorDataset + def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, + reproducible: bool = False): + + # 暂时不支持iterableDataset assert dataloader.dataset_kind != _DatasetKind.ITER, \ - "FastNLP does not support `IteratorDataset` now." + "FastNLP does not support `IteratorDataset` now." 
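The single-device version of the same entry point, handled below, is much simpler since no distributed information needs to be injected. A sketch of its two call sites (names hypothetical):

    # driver.load 断点重训时:dist 为恢复了状态的 sampler,直接替换进 dataloader
    dataloader = driver.set_dist_repro_dataloader(dataloader, dist=restored_batch_sampler)
    # Trainer 初始化时:dist 为 None,仅当 reproducible=True 时才包装出 ReproducibleBatchSampler
    dataloader = driver.set_dist_repro_dataloader(dataloader, dist=None, reproducible=True)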
+ # 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; if isinstance(dist, ReproducibleBatchSampler): -<<<<<<< HEAD return replace_batch_sampler(dataloader, dist) - elif isinstance(dist, ReproducibleIterator): - return replace_sampler(dataloader, dist) - - if reproducible: - args = self.get_dataloader_args(dataloader) - if isinstance(args.sampler, ReproducibleIterator): - sampler = re_instantiate_sampler(args.sampler) - return replace_sampler(dataloader, sampler) - elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): - batch_sampler = re_instantiate_sampler(dataloader.batch_sampler) - return replace_batch_sampler(dataloader, batch_sampler) - else: - batch_sampler = ReproducibleBatchSampler( - batch_sampler=args.batch_sampler, - batch_size=args.batch_size, - drop_last=args.drop_last -======= - dataloader.batch_sampler = dist - return dataloader - if isinstance(dist, ReproducibleSampler): - dataloader.batch_sampler.sampler = dist - return dataloader + elif isinstance(dist, ReproducibleSampler): + return replace_sampler(dataloader, dist) + + # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; + args = self.get_dataloader_args(dataloader) + if isinstance(args.batch_sampler, ReproducibleBatchSampler): + batch_sampler = re_instantiate_sampler(args.batch_sampler) + return replace_batch_sampler(dataloader, batch_sampler) + elif isinstance(args.sampler, ReproducibleSampler): + sampler = re_instantiate_sampler(args.sampler) + return replace_sampler(dataloader, sampler) if reproducible: - if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): - return dataloader - elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): - return dataloader - else: - # TODO - batch_sampler = ReproducibleBatchSampler( - batch_sampler=dataloader.batch_sampler, - batch_size=dataloader.batch_sampler.batch_size, - drop_last=dataloader.drop_last ->>>>>>> 388e426d78e8985a2f34dc83dfffe881274239a1 - ) - return replace_batch_sampler(dataloader, batch_sampler) + batch_sampler = ReproducibleBatchSampler( + batch_sampler=args.batch_sampler, + batch_size=args.batch_size, + drop_last=args.drop_last + ) + return replace_batch_sampler(dataloader, batch_sampler) else: return dataloader From 00b5baf67adda6a61e4f3ede81dd9344b181140d Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 11 Apr 2022 17:39:19 +0000 Subject: [PATCH 10/26] =?UTF-8?q?=E6=95=B4=E7=90=86PaddleSingleDriver?= =?UTF-8?q?=E7=9A=84=E9=83=A8=E5=88=86=E6=B5=8B=E8=AF=95=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../drivers/paddle_driver/paddle_driver.py | 66 ++++----- .../drivers/paddle_driver/single_device.py | 24 ++-- fastNLP/core/drivers/paddle_driver/utils.py | 19 ++- fastNLP/core/utils/paddle_utils.py | 5 +- .../paddle_driver/test_single_device.py | 125 ++++++++---------- 5 files changed, 119 insertions(+), 120 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 95e6215e..89e88aef 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -11,8 +11,13 @@ from .utils import _build_fp16_env, optimizer_state_to_device from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.core.drivers.driver import Driver from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device -from fastNLP.envs import rank_zero_call -from fastNLP.envs import 
FASTNLP_SEED_WORKERS, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME +from fastNLP.envs import ( + FASTNLP_SEED_WORKERS, + FASTNLP_MODEL_FILENAME, + FASTNLP_CHECKPOINT_FILENAME, + FASTNLP_GLOBAL_RANK, + rank_zero_call, +) from fastNLP.core.log import logger from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler @@ -91,7 +96,7 @@ class PaddleDriver(Driver): f"type, not {type(each_dataloader)}.") if isinstance(each_dataloader.dataset, IterableDataset): raise TypeError("`IterableDataset` is not allowed.") - if dataloader.batch_sampler is None and dataloader.batch_size is None: + if each_dataloader.batch_sampler is None and each_dataloader.batch_size is None: raise ValueError(f"For each dataloader of parameter `{dataloader_name}`, at least one of " f"`batch_sampler` and `batch_size` should be set.") @@ -171,56 +176,45 @@ class PaddleDriver(Driver): def save_model(self, filepath: str, only_state_dict: bool = True, **kwargs): r""" 保存模型的函数;注意函数 `save` 是用来进行断点重训的函数; - 如果 `model_save_fn` 是一个可调用的函数,那么我们会直接运行该函数; :param filepath: 保存文件的文件位置(需要包括文件名); - :param only_state_dict: 是否只保存模型的 `state_dict`; + :param only_state_dict: 是否只保存模型的 `state_dict`;如果为 False,则会调用 `paddle.jit.save` 函数 + 保存整个模型的参数,此时需要传入 `input_spec` 参数,否则在 load 时会报错。 :param kwargs: + input_spec: 描述存储模型 forward 方法的输入,当 `only_state_dict` 为 False时必须传入,否则加载时会报错。 + 可以通过 InputSpec 或者示例 Tensor 进行描述。详细的可以参考 paddle 关于`paddle.jit.save` + 的文档: + https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/jit/save_cn.html#save :return: """ - debug = kwargs.get("debug", False) model = self.unwrap_model() if only_state_dict: states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} paddle.save(states, filepath) - if debug: - logger.debug("Save model state dict.") else: # paddle 在保存整个模型时需要传入额外参数 input_spec = kwargs.get("input_spec", None) if input_spec is None: raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.") - if self.model_device is not None: - if not self.is_distributed(): - self.move_model_to_device(model, "cpu") - paddle.jit.save(model, filepath, input_spec) - if not self.is_distributed(): - self.move_model_to_device(model, self.model_device) - else: - paddle.jit.save(model, filepath, input_spec) - if debug: - logger.debug("Save model.") + paddle.jit.save(model, filepath, input_spec) def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs): r""" 加载模型的函数;注意函数 `load` 是用来进行断点重训的函数; :param filepath: 需要被加载的对象的文件位置(需要包括文件名); - :param load_dict: 是否加载state_dict,默认为True。当用户在save_model时将only_state_dict设置为False时, - 即保存了整个模型时,这个参数必须也为False + :param only_state_dict: 是否加载state_dict,默认为True。 :param kwargs: :return: """ - debug = kwargs.get("debug", False) model = self.unwrap_model() - if only_state_dict: - model.load_dict(paddle.load(filepath)) - if debug: - logger.debug("Load model state dict.") - else: - model.load_dict(paddle.jit.load(filepath).state_dict()) - if debug: - logger.debug("Load model.") + # paddle 中,通过 paddle.jit.save 函数保存的模型也可以通过 paddle.load 加载为相应的 state dict + # 但是此时对输入的 path 有要求,必须是 dir/filename 的形式,否则会报错。 + dirname, filename = os.path.split(filepath) + if not only_state_dict and dirname == "": + # 如果传入的是单个文件,则加上相对路径 + filepath = os.path.join(".", filepath) + model.load_dict(paddle.load(filepath)) @rank_zero_call def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): @@ -261,7 +255,11 @@ class PaddleDriver(Driver): # 2. 
保存模型的状态; if should_save_model: - self.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, debug=True, **kwargs) + self.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs) + if only_state_dict: + logger.debug("Save model state dict.") + else: + logger.debug("Save model.") # 3. 保存 optimizers 的状态; optimizers_state_dict = {} @@ -288,7 +286,11 @@ class PaddleDriver(Driver): # 2. 加载模型状态; if should_load_model: - self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, debug=True) + self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict) + if only_state_dict: + logger.debug("Load model state dict.") + else: + logger.debug("Load model.") # 3. 恢复 sampler 的状态; dataloader_args = self.get_dataloader_args(dataloader) @@ -359,7 +361,7 @@ class PaddleDriver(Driver): `randomness in DataLoaders `_. """ # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 - global_rank = rank if rank is not None else rank_zero_call.rank + global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) # TODO gpu process_seed = paddle.fluid.core.default_cpu_generator().initial_seed() # back out the base seed so we can use all the bits diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py index dd5a340a..64656124 100644 --- a/fastNLP/core/drivers/paddle_driver/single_device.py +++ b/fastNLP/core/drivers/paddle_driver/single_device.py @@ -2,7 +2,7 @@ import os from typing import Optional, Dict, Union from .paddle_driver import PaddleDriver -from .utils import replace_batch_sampler, replace_sampler +from .utils import replace_batch_sampler, replace_sampler, get_device_from_visible from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES from fastNLP.core.utils import ( @@ -29,10 +29,7 @@ class PaddleSingleDriver(PaddleDriver): if device is None: raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.") - if isinstance(device, int): - self.model_device = get_paddle_gpu_str(device) - else: - self.model_device = device + self.model_device = get_paddle_gpu_str(device) self.local_rank = 0 self.global_rank = 0 @@ -94,11 +91,14 @@ class PaddleSingleDriver(PaddleDriver): self._test_signature_fn = model.forward def setup(self): - device_id = get_paddle_device_id(self.model_device) - device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id] - os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) - paddle.device.set_device("gpu:0") - self.model.to("gpu:0") + device = self.model_device + if device != "cpu": + device_id = get_paddle_device_id(device) + device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id] + os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) + device = get_device_from_visible(device, output_type=str) + paddle.device.set_device(device) + self.model.to(device) def train_step(self, batch) -> Dict: # 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; @@ -131,11 +131,11 @@ class PaddleSingleDriver(PaddleDriver): r""" 将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。 在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 - 在单卡时,由于 CUDA_VISIBLE_DEVICES 始终被限制在一个设备上,因此实际上只会迁移到 `gpu:0` :return: 将移动到指定机器上的 batch 对象返回; """ - return paddle_move_data_to_device(batch, "gpu:0") + device = get_device_from_visible(self.data_device) + return paddle_move_data_to_device(batch, device) def set_dist_repro_dataloader(self, dataloader, dist: 
Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, reproducible: bool = False): diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py index a8121879..47c0f1b9 100644 --- a/fastNLP/core/drivers/paddle_driver/utils.py +++ b/fastNLP/core/drivers/paddle_driver/utils.py @@ -255,20 +255,23 @@ def get_host_name_ip(): except: return None -def get_device_from_visible(device: Union[str, int]): +def get_device_from_visible(device: Union[str, int], output_type=int): """ 在有 CUDA_VISIBLE_DEVICES 的情况下,获取对应的设备。 如 CUDA_VISIBLE_DEVICES=2,3 ,device=3 ,则返回1。 - :param devices:未转化的设备名 + :param devices: 未转化的设备名 + :param output_type: 返回值的类型 :return: 转化后的设备id """ + if output_type not in [int, str]: + raise ValueError("Parameter `output_type` should be one of these types: [int, str]") if device == "cpu": return device cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") idx = get_paddle_device_id(device) if cuda_visible_devices is None or cuda_visible_devices == "": # 这个判断一般不会发生,因为 fastnlp 会为 paddle 强行注入 CUDA_VISIBLE_DEVICES - return idx + raise RuntimeError("This situation should not happen, please report us this bug.") else: # 利用 USER_CUDA_VISIBLDE_DEVICES 获取用户期望的设备 user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) @@ -277,11 +280,13 @@ def get_device_from_visible(device: Union[str, int]): idx = user_visible_devices.split(",")[idx] cuda_visible_devices_list = cuda_visible_devices.split(',') - assert idx in cuda_visible_devices_list, "Can't find "\ - "your devices %s in CUDA_VISIBLE_DEVICES[%s]."\ - % (idx, cuda_visible_devices) + if idx not in cuda_visible_devices_list: + raise ValueError(f"Can't find your devices {idx} in CUDA_VISIBLE_DEVICES[{cuda_visible_devices}].") res = cuda_visible_devices_list.index(idx) - return res + if output_type == int: + return res + else: + return f"gpu:{res}" def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler"): """ diff --git a/fastNLP/core/utils/paddle_utils.py b/fastNLP/core/utils/paddle_utils.py index 51a19e89..1f461e0f 100644 --- a/fastNLP/core/utils/paddle_utils.py +++ b/fastNLP/core/utils/paddle_utils.py @@ -46,11 +46,14 @@ def get_paddle_device_id(device: Union[str, int]): device = device.lower() if device == "cpu": raise ValueError("Cannot get device id from `cpu`.") + elif device == "gpu": + return 0 match_res = re.match(r"gpu:\d+", device) if not match_res: raise ValueError( - "The device must be a string which is like 'cpu', 'gpu', 'gpu:x'" + "The device must be a string which is like 'cpu', 'gpu', 'gpu:x', " + f"not '{device}'" ) device_id = device.split(':', 1)[1] device_id = int(device_id) diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 8e21c20f..3d07766a 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -1,10 +1,11 @@ import os +from numpy import isin os.environ["FASTNLP_BACKEND"] = "paddle" import pytest from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver from fastNLP.core.samplers.reproducible_sampler import RandomSampler -from fastNLP.core.samplers import ReproducibleBatchSampler +from fastNLP.core.samplers import RandomBatchSampler from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset from tests.helpers.datasets.torch_data import TorchNormalDataset @@ 
-30,6 +31,7 @@ def generate_random_driver(features, labels): opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01) driver = PaddleSingleDriver(model, device="cpu") driver.set_optimizers(opt) + driver.setup() return driver @@ -77,6 +79,7 @@ def test_save_and_load_state_dict(prepare_test_save_load): driver2.load_model(path) for batch in dataloader: + batch = driver1.move_data_to_device(batch) res1 = driver1.validate_step(batch) res2 = driver2.validate_step(batch) @@ -93,10 +96,11 @@ def test_save_and_load_whole_model(prepare_test_save_load): path = "model" driver1, driver2, dataloader = prepare_test_save_load - driver1.save_model(path, only_state_dict=False, input_spec=[next(iter(dataloader))["x"]]) + driver1.save_model(path, only_state_dict=False, input_spec=[paddle.ones((32, 10))]) driver2.load_model(path, only_state_dict=False) for batch in dataloader: + batch = driver1.move_data_to_device(batch) res1 = driver1.validate_step(batch) res2 = driver2.validate_step(batch) @@ -115,7 +119,7 @@ class TestSingleDeviceFunction: @classmethod def setup_class(cls): model = PaddleNormalModel_Classification_1(10, 784) - cls.driver = PaddleSingleDriver(model, device="gpu") + cls.driver = PaddleSingleDriver(model, device="cpu") def test_unwrap_model(self): """ @@ -130,22 +134,6 @@ class TestSingleDeviceFunction: self.driver.check_evaluator_mode("validate") self.driver.check_evaluator_mode("test") - def test_get_model_device_cpu(self): - """ - 测试get_model_device - """ - self.driver = PaddleSingleDriver(PaddleNormalModel_Classification_1(10, 784), "cpu") - device = self.driver.get_model_device() - assert device == "cpu", device - - def test_get_model_device_gpu(self): - """ - 测试get_model_device - """ - self.driver = PaddleSingleDriver(PaddleNormalModel_Classification_1(10, 784), "gpu:0") - device = self.driver.get_model_device() - assert device == "gpu:0", device - def test_is_distributed(self): assert self.driver.is_distributed() == False @@ -156,24 +144,24 @@ class TestSingleDeviceFunction: """ self.driver.move_data_to_device(paddle.rand((32, 64))) - @pytest.mark.parametrize( - "dist_sampler", [ - "dist", - ReproducibleBatchSampler(BatchSampler(PaddleRandomMaxDataset(320, 10)), 32, False), - RandomSampler(PaddleRandomMaxDataset(320, 10)) - ] - ) - @pytest.mark.parametrize( - "reproducible", - [True, False] - ) - def test_repalce_sampler(self, dist_sampler, reproducible): - """ - 测试set_dist_repro_dataloader函数 - """ - dataloader = DataLoader(PaddleRandomMaxDataset(320, 10), batch_size=100, shuffle=True) - - res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible) + # @pytest.mark.parametrize( + # "dist_sampler", [ + # "dist", + # RandomBatchSampler(BatchSampler(PaddleRandomMaxDataset(320, 10)), 32, False), + # RandomSampler(PaddleRandomMaxDataset(320, 10)) + # ] + # ) + # @pytest.mark.parametrize( + # "reproducible", + # [True, False] + # ) + # def test_set_dist_repro_dataloader(self, dist_sampler, reproducible): + # """ + # 测试set_dist_repro_dataloader函数 + # """ + # dataloader = DataLoader(PaddleRandomMaxDataset(320, 10), batch_size=100, shuffle=True) + + # res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible) class TestPaddleDriverFunctions: """ @@ -183,7 +171,7 @@ class TestPaddleDriverFunctions: @classmethod def setup_class(self): model = PaddleNormalModel_Classification_1(10, 32) - self.driver = PaddleSingleDriver(model, device="gpu") + self.driver = PaddleSingleDriver(model, device="cpu") def 
test_check_single_optimizer_legality(self): """ @@ -198,7 +186,7 @@ class TestPaddleDriverFunctions: optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) # 传入torch的optimizer时,应该报错ValueError - with self.assertRaises(ValueError) as cm: + with pytest.raises(ValueError): self.driver.set_optimizers(optimizer) def test_check_optimizers_legality(self): @@ -218,7 +206,7 @@ class TestPaddleDriverFunctions: torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) ] - with self.assertRaises(ValueError) as cm: + with pytest.raises(ValueError): self.driver.set_optimizers(optimizers) def test_check_dataloader_legality_in_train(self): @@ -230,7 +218,7 @@ class TestPaddleDriverFunctions: # batch_size 和 batch_sampler 均为 None 的情形 dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) - with self.assertRaises(ValueError) as cm: + with pytest.raises(ValueError): PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) # 创建torch的dataloader @@ -238,7 +226,7 @@ class TestPaddleDriverFunctions: TorchNormalDataset(), batch_size=32, shuffle=True ) - with self.assertRaises(ValueError) as cm: + with pytest.raises(ValueError): PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) def test_check_dataloader_legacy_in_test(self): @@ -257,11 +245,12 @@ class TestPaddleDriverFunctions: "train": paddle.io.DataLoader(PaddleNormalDataset()), "test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) } - PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) + with pytest.raises(ValueError): + PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) # 传入的不是dict,应该报错 dataloader = paddle.io.DataLoader(PaddleNormalDataset()) - with self.assertRaises(ValueError) as cm: + with pytest.raises(ValueError): PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) # 创建torch的dataloader @@ -274,7 +263,7 @@ class TestPaddleDriverFunctions: batch_size=32, shuffle=True ) dataloader = {"train": train_loader, "test": test_loader} - with self.assertRaises(ValueError) as cm: + with pytest.raises(ValueError): PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) def test_tensor_to_numeric(self): @@ -284,25 +273,25 @@ class TestPaddleDriverFunctions: # 单个张量 tensor = paddle.to_tensor(3) res = PaddleSingleDriver.tensor_to_numeric(tensor) - self.assertEqual(res, 3) + assert res == 3 tensor = paddle.rand((3, 4)) res = PaddleSingleDriver.tensor_to_numeric(tensor) - self.assertListEqual(res, tensor.tolist()) + assert res == tensor.tolist() # 张量list tensor_list = [paddle.rand((6, 4, 2)) for i in range(10)] res = PaddleSingleDriver.tensor_to_numeric(tensor_list) - self.assertTrue(res, list) + assert isinstance(res, list) tensor_list = [t.tolist() for t in tensor_list] - self.assertListEqual(res, tensor_list) + assert res == tensor_list # 张量tuple tensor_tuple = tuple([paddle.rand((6, 4, 2)) for i in range(10)]) res = PaddleSingleDriver.tensor_to_numeric(tensor_tuple) - self.assertTrue(res, tuple) + assert isinstance(res, tuple) tensor_tuple = tuple([t.tolist() for t in tensor_tuple]) - self.assertTupleEqual(res, tensor_tuple) + assert res == tensor_tuple # 张量dict tensor_dict = { @@ -317,29 +306,29 @@ class TestPaddleDriverFunctions: } res = PaddleSingleDriver.tensor_to_numeric(tensor_dict) - self.assertIsInstance(res, dict) - self.assertListEqual(res["tensor"], tensor_dict["tensor"].tolist()) - self.assertIsInstance(res["list"], list) 
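    # 下同:unittest 风格断言在本补丁中统一替换为 pytest 风格,
    # 例如 self.assertListEqual(a, b) -> assert a == b,
    # with self.assertRaises(ValueError) -> with pytest.raises(ValueError)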
+ assert isinstance(res, dict) + assert res["tensor"] == tensor_dict["tensor"].tolist() + assert isinstance(res["list"], list) for r, d in zip(res["list"], tensor_dict["list"]): - self.assertListEqual(r, d.tolist()) - self.assertIsInstance(res["int"], int) - self.assertIsInstance(res["string"], str) - self.assertIsInstance(res["dict"], dict) - self.assertIsInstance(res["dict"]["list"], list) + assert r == d.tolist() + assert isinstance(res["int"], int) + assert isinstance(res["string"], str) + assert isinstance(res["dict"], dict) + assert isinstance(res["dict"]["list"], list) for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]): - self.assertListEqual(r, d.tolist()) - self.assertListEqual(res["dict"]["tensor"], tensor_dict["dict"]["tensor"].tolist()) + assert r == d.tolist() + assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() def test_set_model_mode(self): """ 测试set_model_mode函数 """ self.driver.set_model_mode("train") - self.assertTrue(self.driver.model.training) + assert self.driver.model.training self.driver.set_model_mode("eval") - self.assertFalse(self.driver.model.training) + assert not self.driver.model.training # 应该报错 - with self.assertRaises(AssertionError) as cm: + with pytest.raises(AssertionError): self.driver.set_model_mode("test") def test_move_model_to_device_cpu(self): @@ -347,15 +336,15 @@ class TestPaddleDriverFunctions: 测试move_model_to_device函数 """ PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu") - self.assertTrue(self.driver.model.fc1.weight.place.is_cpu_place()) + assert self.driver.model.linear1.weight.place.is_cpu_place() def test_move_model_to_device_gpu(self): """ 测试move_model_to_device函数 """ - PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu:0") - self.assertTrue(self.driver.model.fc1.weight.place.is_gpu_place()) - self.assertEqual(self.driver.model.fc1.weight.place.gpu_device_id(), 0) + PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu") + assert self.driver.model.linear1.weight.place.is_gpu_place() + assert self.driver.model.linear1.weight.place.gpu_device_id() == 0 def test_worker_init_function(self): """ From 5419b6a04295ebd35ab221701757d7b1afeadb9c Mon Sep 17 00:00:00 2001 From: YWMditto Date: Tue, 12 Apr 2022 17:00:07 +0800 Subject: [PATCH 11/26] =?UTF-8?q?=E5=A1=AB=E4=BA=86=E4=BA=86=E5=85=B3?= =?UTF-8?q?=E9=97=AD=E5=8F=82=E6=95=B0=E5=8C=B9=E9=85=8D=E7=9A=84=E9=80=BB?= =?UTF-8?q?=E8=BE=91=EF=BC=9B=E6=B7=BB=E5=8A=A0=E4=BA=86=20trainer=20?= =?UTF-8?q?=E4=B8=AD=E8=8E=B7=E5=8F=96=20driver=20=E5=8F=82=E6=95=B0?= =?UTF-8?q?=E7=9A=84=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/controllers/evaluator.py | 6 ++- fastNLP/core/controllers/trainer.py | 28 ++++++++++--- fastNLP/core/drivers/torch_driver/ddp.py | 20 ++++----- .../drivers/torch_driver/single_device.py | 10 ++--- .../core/drivers/torch_driver/torch_driver.py | 7 +++- fastNLP/core/drivers/torch_driver/utils.py | 17 ++++---- fastNLP/core/utils/utils.py | 2 + .../test_checkpoint_callback_torch.py | 2 +- .../test_trainer_wo_evaluator_torch.py | 41 ++++++++++++++++++- tests/helpers/models/torch_model.py | 27 ++++++++++++ 10 files changed, 125 insertions(+), 35 deletions(-) diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 865acc89..b193f877 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -41,6 +41,7 @@ class Evaluator: mode: str = "validate", input_mapping: 
Optional[Union[Callable, Dict]] = None, output_mapping: Optional[Union[Callable, Dict]] = None, + model_wo_auto_param_call: bool = False, fp16: Optional[bool] = False, verbose: int = 1, **kwargs @@ -61,6 +62,9 @@ class Evaluator: 没有的话尝试 "validate_step" 函数,都没找到则使用 model 的前向运算函数。 :param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中 :param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 + :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; + 如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 + 为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`; :param fp16: 是否使用 fp16 。 :param verbose: 是否打印 evaluate 的结果。 :param kwargs: @@ -83,7 +87,7 @@ class Evaluator: self.model = model self.metrics = metrics - self.driver = choose_driver(model, driver, device, fp16=fp16, **kwargs) + self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, **kwargs) self.device = device self.verbose = verbose diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index d710f967..a7c38b27 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -47,6 +47,7 @@ class Trainer(TrainerEventTrigger): validate_every: Optional[Union[int, callable]] = -1, input_mapping: Optional[Union[Callable, Dict]] = None, output_mapping: Optional[Union[Callable, Dict]] = None, + model_wo_auto_param_call: bool = False, accumulation_steps: int = 1, fp16: bool = False, marker: Optional[str] = None, @@ -99,7 +100,10 @@ class Trainer(TrainerEventTrigger): :param output_mapping: 应当为一个字典或者函数。作用和 input_mapping 类似,区别在于其用于转换输出;如果 output_mapping 是一个 函数,那么我们将会直接将模型的输出传给该函数;如果其是一个 `Dict`,那么我们需要 batch 必须是 `Dict` 或者 `dataclass` 类型, 如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value; - 如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换 + 如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; + :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; + 如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 + 为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`; :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; :param fp16: 是否开启混合精度训练;默认为 False; :param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; @@ -120,9 +124,7 @@ class Trainer(TrainerEventTrigger): """ - # TODO 是不是可以加一个参数让用户现在关掉参数匹配。 self.marker = marker - self.model = model self.driver_name = driver self.device = device self.fp16 = fp16 @@ -164,6 +166,7 @@ class Trainer(TrainerEventTrigger): validate_every=validate_every, input_mapping=input_mapping, output_mapping=output_mapping, + model_wo_auto_param_call=model_wo_auto_param_call, accumulation_steps=accumulation_steps, fp16=fp16, marker=marker, @@ -484,8 +487,6 @@ class Trainer(TrainerEventTrigger): @driver.setter def driver(self, driver: Driver): - driver.trainer = self - driver.model = self.model self._driver = driver @property @@ -782,4 +783,21 @@ class Trainer(TrainerEventTrigger): def total_batches(self, total_batches: int): self.trainer_state.total_batches = total_batches + """ driver property """ + + @property + def model_device(self): + return self.driver.model_device + + @property + def data_device(self): + return 
self.driver.data_device + + @property + def model(self): + # 返回 driver 中的 model,注意该 model 可能被分布式的模型包裹,例如 `DistributedDataParallel`; + return self.driver.model + + + diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index 44cabcf4..4cf207cd 100644 --- a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -167,6 +167,7 @@ class TorchDDPDriver(TorchDriver): 不管是什么情况,`TorchDDPDriver` 在 `setup` 函数的最后,都会将所有进程的 pid 主动记录下来,这样当一个进程出现 exception 后, driver 的 on_exception 函数就会被 trainer 调用,其会调用 os.kill 指令将其它进程 kill 掉; """ + # 在加入很多东西后,需要注意这里调用 super 函数的位置; super(TorchDDPDriver, self).__init__(model, fp16=fp16, **kwargs) if isinstance(model, torch.nn.DataParallel): @@ -202,8 +203,8 @@ class TorchDDPDriver(TorchDriver): # 我们就直接将 model_device 置为 None; self.model_device = None - def _running_fn_(batch, step_fn, signature_fn): - if isinstance(batch, Dict): + def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call): + if isinstance(batch, Dict) and not wo_auto_param_call: return auto_param_call(step_fn, batch, signature_fn=signature_fn) else: return step_fn(batch) @@ -214,7 +215,7 @@ class TorchDDPDriver(TorchDriver): "Notice your model is a `DistributedDataParallel` model. And your " "model also implements the `train_step` method, which we can not call actually, we will" " call `forward` function instead of `train_step` and you should note that.") - self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) + self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) # self._train_signature_fn = model.forward if hasattr(model, "validate_step"): @@ -222,7 +223,7 @@ class TorchDDPDriver(TorchDriver): "Notice your model is a `DistributedDataParallel` model. And your " "model also implements the `validate_step` method, which we can not call actually, " "we will call `forward` function instead of `validate_step` and you should note that.") - self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) + self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) # self._validate_signature_fn = model.forward if hasattr(model, "test_step"): @@ -230,14 +231,11 @@ class TorchDDPDriver(TorchDriver): "Notice your model is a `DistributedDataParallel` model. 
And your " "model also implements the `test_step` method, which we can not call actually, we will" " call `forward` function instead of `test_step` and you should note that.") - self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) + self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) # self._test_signature_fn = model.forward # 当用户自己在外面初始化 DDP 时我们会将 model_device 置为 None,这是用户可以通过 `data_device` 将对应的数据移到指定的机器上; self._data_device = kwargs.get("data_device", None) - # if self.outside_ddp and self._data_device is None: - # raise RuntimeError("When you initialize your ddp out of our control, the parameter " - # "`data_device` can not be None.") if isinstance(self._data_device, int): if self._data_device < 0: raise ValueError("Parameter `data_device` can not be smaller than 0.") @@ -349,9 +347,9 @@ class TorchDDPDriver(TorchDriver): **self._ddp_kwargs ) - self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}) - self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}) - self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}) + self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call) + self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call) + self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call) self._configured = True diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py index 19e687b8..8cbb7acd 100644 --- a/fastNLP/core/drivers/torch_driver/single_device.py +++ b/fastNLP/core/drivers/torch_driver/single_device.py @@ -13,7 +13,7 @@ __all__ = [ from .torch_driver import TorchDriver from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler from fastNLP.core.utils import auto_param_call -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler from fastNLP.core.log import logger @@ -102,7 +102,7 @@ class TorchSingleDriver(TorchDriver): def train_step(self, batch) -> Dict: # 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; - if isinstance(batch, Dict): + if isinstance(batch, Dict) and not self.wo_auto_param_call: return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) else: return self._train_step(batch) @@ -118,13 +118,13 @@ class TorchSingleDriver(TorchDriver): def validate_step(self, batch) -> Dict: # 因为我们 Tester 的逻辑就是将所有的 metric 传给 tester,然后 tester 控制具体 metric 的 update 和 compute;因此不管用户是否 # 实现 validate_step 函数,其都应该返回一个字典,具体使用哪些东西则是在 validate_batch_loop 中每一个具体的 metric 自己去拿的; - if isinstance(batch, Dict): + if isinstance(batch, Dict) and not self.wo_auto_param_call: return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) else: return self._validate_step(batch) def test_step(self, batch) -> Dict: - if isinstance(batch, Dict): + if isinstance(batch, Dict) and not self.wo_auto_param_call: return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) else: return self._test_step(batch) @@ -148,7 +148,7 @@ class TorchSingleDriver(TorchDriver): 
return replace_sampler(dataloader, sampler) if reproducible: - batch_sampler = ReproducibleBatchSampler( + batch_sampler = RandomBatchSampler( batch_sampler=args.batch_sampler, batch_size=args.batch_size, drop_last=args.drop_last diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index b200f1fd..d2ffbac1 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device from fastNLP.envs import rank_zero_call from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.core.log import logger -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler class TorchDriver(Driver): @@ -51,6 +51,9 @@ class TorchDriver(Driver): # 用来设置 `torch_move_data_to_device` 中的 `non_blocking` 参数; self.non_blocking = kwargs.get("torch_non_blocking", True) + # 用来设置是否关闭 auto_param_call 中的参数匹配问题; + self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) + def zero_grad(self, set_to_none: bool = False): for optimizer in self.optimizers: self._clear_grad(optimizer, set_to_none) @@ -252,7 +255,7 @@ class TorchDriver(Driver): elif self.is_distributed(): raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") else: - sampler = ReproducibleBatchSampler( + sampler = RandomBatchSampler( batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, batch_size=dataloader_args.batch_size, drop_last=dataloader_args.drop_last diff --git a/fastNLP/core/drivers/torch_driver/utils.py b/fastNLP/core/drivers/torch_driver/utils.py index 406e030b..4210dac5 100644 --- a/fastNLP/core/drivers/torch_driver/utils.py +++ b/fastNLP/core/drivers/torch_driver/utils.py @@ -140,24 +140,25 @@ class _DDPWrappingModel(Module): pytorch lightning 实现了先 unwrapping_model 的操作,但是感觉对于我们来说没有什么必须要,先写个注释放这里,之后有需求了再看; """ - _forward_state = kwargs.pop(_MODE_PARAMETER) + forward_state = kwargs.pop(_MODE_PARAMETER) + wo_auto_param_call = kwargs.pop("wo_auto_param_call") - if _forward_state == ForwardState.TRAIN: - if isinstance(batch, Dict): + if forward_state == ForwardState.TRAIN: + if isinstance(batch, Dict) and not wo_auto_param_call: return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) else: return self._train_step(batch) - elif _forward_state == ForwardState.VALIDATE: - if isinstance(batch, Dict): + elif forward_state == ForwardState.VALIDATE: + if isinstance(batch, Dict) and not wo_auto_param_call: return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) else: return self._validate_step(batch) - elif _forward_state == ForwardState.TEST: - if isinstance(batch, Dict): + elif forward_state == ForwardState.TEST: + if isinstance(batch, Dict) and not wo_auto_param_call: return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) else: return self._test_step(batch) - elif _forward_state == ForwardState.PREDICT: + elif forward_state == ForwardState.PREDICT: raise NotImplementedError("'PREDICT' mode has not been implemented.") else: raise NotImplementedError("You should direct a concrete mode.") diff --git a/fastNLP/core/utils/utils.py 
b/fastNLP/core/utils/utils.py index 0d497bc2..5c497606 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -96,6 +96,7 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None :param signature_fn: 函数,用来替换 `fn` 的函数签名,如果该参数不为 None,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取 参数值后,再传给 `fn` 进行实际的运算; :param mapping: 一个字典,用来更改其前面的字典的键值; + :param wo_auto_param_call: 是否关闭默认的参数匹配行为; :return: 返回 `fn` 运行的结果; @@ -113,6 +114,7 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None >>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20})) # res: 140 >>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200})) # res: 240 """ + if signature_fn is not None: if not callable(signature_fn): raise ValueError(f"Parameter `signature_fn` should be `Callable`.") diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py index 1f404bb8..557c31b2 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -10,7 +10,7 @@ import re from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback from fastNLP.core.controllers.trainer import Trainer -from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK +from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK from tests.helpers.utils import magic_argv_env_context from fastNLP.core import synchronize_safe_rm diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py index 0a280a0c..0da8c976 100644 --- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py @@ -10,7 +10,7 @@ from typing import Any from pathlib import Path from fastNLP.core.controllers.trainer import Trainer -from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 +from tests.helpers.models.torch_model import TorchNormalModel_Classification_1, TorchNormalModel_Classification_3 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification from tests.helpers.callbacks.helper_callbacks import RecordLossCallback from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch @@ -70,7 +70,7 @@ def model_and_optimizers(request): trainer_params.output_mapping = None # elif request.param == 1: - # model = + return trainer_params @@ -307,10 +307,47 @@ def test_torch_distributed_launch_2(version): subprocess.check_call(command) +@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch_ddp", [0, 1])]) +@magic_argv_env_context +def test_torch_wo_auto_param_call( + driver, + device, + n_epochs=10, +): + + model = TorchNormalModel_Classification_3( + num_labels=NormalClassificationTrainTorchConfig.num_labels, + feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension + ) + optimizers = SGD(model.parameters(), lr=0.001) + dataset = TorchNormalDataset_Classification( + num_labels=NormalClassificationTrainTorchConfig.num_labels, + feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension, + each_label_data=NormalClassificationTrainTorchConfig.each_label_data, + seed=NormalClassificationTrainTorchConfig.seed + ) + train_dataloader = DataLoader( + dataset=dataset, + 
batch_size=NormalClassificationTrainTorchConfig.batch_size, + shuffle=True + ) + trainer = Trainer( + model=model, + driver=driver, + device=device, + optimizers=optimizers, + train_dataloader=train_dataloader, + n_epochs=n_epochs, + model_wo_auto_param_call=True, + output_from_new_proc="all" + ) + trainer.run() + if dist.is_initialized(): + dist.destroy_process_group() diff --git a/tests/helpers/models/torch_model.py b/tests/helpers/models/torch_model.py index 2912224f..b949a26f 100644 --- a/tests/helpers/models/torch_model.py +++ b/tests/helpers/models/torch_model.py @@ -37,6 +37,7 @@ class TorchNormalModel_Classification_1(nn.Module): x = torch.max(x, dim=-1)[1] return {"preds": x, "target": y} + class TorchNormalModel_Classification_2(nn.Module): """ 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; @@ -61,5 +62,31 @@ class TorchNormalModel_Classification_2(nn.Module): return {"loss": loss, "preds": x, "target": y} +class TorchNormalModel_Classification_3(nn.Module): + """ + 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; + 关闭 auto_param_call,forward 只有一个 batch 参数; + """ + def __init__(self, num_labels, feature_dimension): + super(TorchNormalModel_Classification_3, self).__init__() + self.num_labels = num_labels + + self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10) + self.ac1 = nn.ReLU() + self.linear2 = nn.Linear(in_features=10, out_features=10) + self.ac2 = nn.ReLU() + self.output = nn.Linear(in_features=10, out_features=num_labels) + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, batch): + x = batch["x"] + y = batch["y"] + x = self.ac1(self.linear1(x)) + x = self.ac2(self.linear2(x)) + x = self.output(x) + loss = self.loss_fn(x, y) + x = torch.max(x, dim=-1)[1] + return {"loss": loss, "preds": x, "target": y} + From a5b2ccf7590dc9fb5149a7697ceeae61a4c839b8 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 12 Apr 2022 09:36:51 +0000 Subject: [PATCH 12/26] update --- fastNLP/core/drivers/paddle_driver/fleet.py | 18 ++++++++---------- .../drivers/paddle_driver/paddle_driver.py | 7 +++++-- .../drivers/paddle_driver/single_device.py | 15 ++++++++++----- fastNLP/core/drivers/paddle_driver/utils.py | 19 ++++++++++--------- 4 files changed, 33 insertions(+), 26 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index 86198959..582ce542 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -104,8 +104,8 @@ class PaddleFleetDriver(PaddleDriver): # 我们就直接将 model_device 置为 None; self._model_device = None - def _running_fn_(batch, step_fn, signature_fn): - if isinstance(batch, Dict): + def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call): + if isinstance(batch, Dict) and not wo_auto_param_call: return auto_param_call(step_fn, batch, signature_fn=signature_fn) else: return self._validate_step(batch) @@ -116,23 +116,21 @@ class PaddleFleetDriver(PaddleDriver): "Notice your model is a `paddle.DataParallel` model. 
And your " "model also implements the `train_step` method, which we can not call actually, we will" " call `forward` function instead of `train_step` and you should note that.") - self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) - # self._train_signature_fn = model.forward + self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) if hasattr(model, "validate_step"): logger.warning( "Notice your model is a `paddle.DataParallel` model. And your " "model also implements the `validate_step` method, which we can not call actually, " "we will call `forward` function instead of `validate_step` and you should note that.") - self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) - # self._validate_signature_fn = model.forward + self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) if hasattr(model, "test_step"): logger.warning( "Notice your model is a `paddle.DataParallel` model. And your " "model also implements the `test_step` method, which we can not call actually, we will" " call `forward` function instead of `test_step` and you should note that.") - self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) + self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) # 当参数 `device` 为 None 时并且该参数不为 None,表示将对应的数据移到指定的机器上; self._data_device = kwargs.get("data_device", None) @@ -277,9 +275,9 @@ class PaddleFleetDriver(PaddleDriver): **self._fleet_kwargs ) - self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}) - self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}) - self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}) + self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call) + self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call) + self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call) self._configured = True diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 89e88aef..a407a7b7 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -19,7 +19,7 @@ from fastNLP.envs import ( rank_zero_call, ) from fastNLP.core.log import logger -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler if _NEED_IMPORT_PADDLE: import paddle @@ -56,6 +56,9 @@ class PaddleDriver(Driver): self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) self.grad_scaler = _grad_scaler() + # 用来设置是否关闭 auto_param_call 中的参数匹配问题; + self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) + def zero_grad(self, set_to_none: bool = False): r""" 实现深度学习中的梯度的置零操作,应当直接通过优化器 optimizers 来将梯度置零; @@ -301,7 +304,7 @@ class PaddleDriver(Driver): elif self.is_distributed(): raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") else: - sampler = ReproducibleBatchSampler( + sampler 
= RandomBatchSampler( batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, batch_size=dataloader_args.batch_size, drop_last=dataloader_args.drop_last diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py index 64656124..796f4809 100644 --- a/fastNLP/core/drivers/paddle_driver/single_device.py +++ b/fastNLP/core/drivers/paddle_driver/single_device.py @@ -11,7 +11,12 @@ from fastNLP.core.utils import ( get_paddle_device_id, paddle_move_data_to_device, ) -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler +from fastNLP.core.samplers import ( + ReproducibleBatchSampler, + RandomBatchSampler, + ReproducibleSampler, + re_instantiate_sampler, +) from fastNLP.core.log import logger if _NEED_IMPORT_PADDLE: @@ -102,7 +107,7 @@ class PaddleSingleDriver(PaddleDriver): def train_step(self, batch) -> Dict: # 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; - if isinstance(batch, Dict): + if isinstance(batch, Dict) and not self.wo_auto_param_call: return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) else: return self._train_step(batch) @@ -116,13 +121,13 @@ class PaddleSingleDriver(PaddleDriver): self.grad_scaler.update() def validate_step(self, batch) -> Dict: - if isinstance(batch, Dict): + if isinstance(batch, Dict) and not self.wo_auto_param_call: return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) else: return self._validate_step(batch) def test_step(self, batch) -> Dict: - if isinstance(batch, Dict): + if isinstance(batch, Dict) and not self.wo_auto_param_call: return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) else: return self._test_step(batch) @@ -159,7 +164,7 @@ class PaddleSingleDriver(PaddleDriver): return replace_sampler(dataloader, sampler) if reproducible: - batch_sampler = ReproducibleBatchSampler( + batch_sampler = RandomBatchSampler( batch_sampler=args.batch_sampler, batch_size=args.batch_size, drop_last=args.drop_last diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py index 47c0f1b9..36982b4c 100644 --- a/fastNLP/core/drivers/paddle_driver/utils.py +++ b/fastNLP/core/drivers/paddle_driver/utils.py @@ -85,7 +85,7 @@ class ForwardState(IntEnum): TEST = 2 PREDICT = 3 -_MODE_PARAMETER = "_forward_state" +_MODE_PARAMETER = "forward_state" class _FleetWrappingModel(Layer): """ @@ -151,24 +151,25 @@ class _FleetWrappingModel(Layer): def forward(self, batch, **kwargs) -> Dict: - _forward_state = kwargs.pop(_MODE_PARAMETER) + forward_state = kwargs.pop(_MODE_PARAMETER) + wo_auto_param_call = kwargs.pop("wo_auto_param_call") - if _forward_state == ForwardState.TRAIN: - if isinstance(batch, Dict): + if forward_state == ForwardState.TRAIN: + if isinstance(batch, Dict) and not wo_auto_param_call: return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) else: return self._train_step(batch) - elif _forward_state == ForwardState.VALIDATE: - if isinstance(batch, Dict): + elif forward_state == ForwardState.VALIDATE: + if isinstance(batch, Dict) and not wo_auto_param_call: return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) else: return self._validate_step(batch) - elif _forward_state == ForwardState.TEST: - if isinstance(batch, Dict): + elif forward_state == ForwardState.TEST: + 
if isinstance(batch, Dict) and not wo_auto_param_call: return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) else: return self._test_step(batch) - elif _forward_state == ForwardState.PREDICT: + elif forward_state == ForwardState.PREDICT: raise NotImplementedError("'PREDICT' mode has not been implemented.") else: raise NotImplementedError("You should direct a concrete mode.") From 8c22d0b1f61101fa4d32888367e80abfa1923318 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Tue, 12 Apr 2022 22:47:39 +0800 Subject: [PATCH 13/26] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86=20Trainer.on?= =?UTF-8?q?=20=E7=9A=84=E9=94=99=E8=AF=AF=E6=8F=90=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/callback.py | 2 +- fastNLP/core/controllers/trainer.py | 20 ++++++++------- fastNLP/core/utils/utils.py | 4 --- tests/core/callbacks/test_callback_events.py | 2 +- .../controllers/test_trainer_other_things.py | 25 +++++++++++++++++++ 5 files changed, 38 insertions(+), 15 deletions(-) create mode 100644 tests/core/controllers/test_trainer_other_things.py diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py index 4b553a1f..99e47dfe 100644 --- a/fastNLP/core/callbacks/callback.py +++ b/fastNLP/core/callbacks/callback.py @@ -71,7 +71,7 @@ class Callback: """ pass - def on_train_batch_begin(self, trainer, batch, indices=None): + def on_train_batch_begin(self, trainer, batch, indices): r""" 在训练过程中开始具体的一个 batch 前会被触发; diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index af589cbf..6d154770 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -130,9 +130,12 @@ class Trainer(TrainerEventTrigger): auto 表示如果检测到当前 terminal 为交互型 则使用 rich,否则使用 raw。 """ - + self.model = model self.marker = marker - self.driver_name = driver + if isinstance(driver, str): + self.driver_name = driver + else: + self.driver_name = driver.__class__.__name__ self.device = device self.fp16 = fp16 self.input_mapping = input_mapping @@ -157,6 +160,8 @@ class Trainer(TrainerEventTrigger): elif accumulation_steps < 0: raise ValueError("Parameter `accumulation_steps` can only be bigger than 0.") self.accumulation_steps = accumulation_steps + + # todo 思路大概是,每个driver提供一下自己的参数是啥(需要对应回初始化的那个),然后trainer/evalutor在初始化的时候,就检测一下自己手上的参数和driver的是不是一致的,不一致的地方需要warn用户说这些值driver不太一样。感觉可以留到后面做吧 self.driver = choose_driver( model=model, driver=driver, @@ -403,9 +408,10 @@ class Trainer(TrainerEventTrigger): def wrapper(fn: Callable) -> Callable: cls._custom_callbacks[marker].append((event, fn)) - assert check_fn_not_empty_params(fn, len(get_fn_arg_names(getattr(Callback, event.value))) - 1), "Your " \ - "callback fn's allowed parameters seem not to be equal with the origin callback fn in class " \ - "`Callback` with the same callback time." + callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:] + assert check_fn_not_empty_params(fn, len(callback_fn_args)), \ + f"The callback function at `{event.value.lower()}`'s parameters should be {callback_fn_args}, but your "\ + f"function {fn.__name__} only has these parameters: {get_fn_arg_names(fn)}." 
return fn return wrapper @@ -807,10 +813,6 @@ class Trainer(TrainerEventTrigger): def data_device(self): return self.driver.data_device - @property - def model(self): - # 返回 driver 中的 model,注意该 model 可能被分布式的模型包裹,例如 `DistributedDataParallel`; - return self.driver.model diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index 5c497606..46211581 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -44,15 +44,11 @@ __all__ = [ ] - - - def get_fn_arg_names(fn: Callable) -> List[str]: r""" 返回一个函数的所有参数的名字; :param fn: 需要查询的函数; - :return: 一个列表,其中的元素则是查询函数的参数的字符串名字; """ return list(inspect.signature(fn).parameters) diff --git a/tests/core/callbacks/test_callback_events.py b/tests/core/callbacks/test_callback_events.py index a71bb07f..8712b469 100644 --- a/tests/core/callbacks/test_callback_events.py +++ b/tests/core/callbacks/test_callback_events.py @@ -1,7 +1,7 @@ import pytest from functools import reduce -from fastNLP.core.callbacks.callback_events import Filter +from fastNLP.core.callbacks.callback_events import Events, Filter class TestFilter: diff --git a/tests/core/controllers/test_trainer_other_things.py b/tests/core/controllers/test_trainer_other_things.py new file mode 100644 index 00000000..6327f4f8 --- /dev/null +++ b/tests/core/controllers/test_trainer_other_things.py @@ -0,0 +1,25 @@ +import pytest + +from fastNLP.core.controllers.trainer import Trainer +from fastNLP.core.callbacks import Events +from tests.helpers.utils import magic_argv_env_context + + +@magic_argv_env_context +def test_trainer_torch_without_evaluator(): + @Trainer.on(Events.ON_TRAIN_EPOCH_BEGIN(every=10)) + def fn1(trainer): + pass + + @Trainer.on(Events.ON_TRAIN_BATCH_BEGIN(every=10)) + def fn2(trainer, batch, indices): + pass + + with pytest.raises(AssertionError): + @Trainer.on(Events.ON_TRAIN_BATCH_BEGIN(every=10)) + def fn3(trainer, batch): + pass + + + + From e8d11cd5a9ec53bd2c6da16e1c14cb025e489415 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Wed, 13 Apr 2022 12:55:28 +0800 Subject: [PATCH 14/26] =?UTF-8?q?1.=20=E4=BF=AE=E5=A4=8Dtorch=20=E5=88=86?= =?UTF-8?q?=E5=B8=83=E5=BC=8F=E5=9C=A8=E4=B8=8D=E5=90=8C=E7=89=88=E6=9C=AC?= =?UTF-8?q?=E4=B8=ADgroup=E5=8F=82=E6=95=B0default=E5=80=BC=E4=B8=8D?= =?UTF-8?q?=E4=B8=80=E6=A0=B7=E7=9A=84=E9=97=AE=E9=A2=98;=202.=20torch?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=A4=9A=E5=8D=A1=E6=97=B6=E5=8F=AA=E6=9C=89?= =?UTF-8?q?batchsampler=20evaluate=E4=BC=9A=E9=81=87=E5=88=B0bug=E7=9A=84?= =?UTF-8?q?=E9=97=AE=E9=A2=98;=203=E3=80=82logger=E5=A2=9E=E5=8A=A0warning?= =?UTF-8?q?=5Fonce=E6=8E=A5=E5=8F=A3;4.=E5=A2=9E=E5=8A=A0callback=E7=9B=B8?= =?UTF-8?q?=E5=85=B3=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/callback.py | 142 ++++++++++++++++-- fastNLP/core/callbacks/checkpoint_callback.py | 19 +-- fastNLP/core/collators/collator.py | 2 +- fastNLP/core/controllers/evaluator.py | 19 +-- .../controllers/loops/train_batch_loop.py | 9 +- fastNLP/core/controllers/trainer.py | 2 +- .../core/dataloaders/torch_dataloader/fdl.py | 2 +- fastNLP/core/dataset/dataset.py | 6 +- fastNLP/core/drivers/torch_driver/ddp.py | 4 +- .../core/drivers/torch_driver/dist_utils.py | 61 ++++---- fastNLP/core/drivers/torch_driver/utils.py | 13 +- fastNLP/core/log/logger.py | 16 ++ .../samplers/reproducible_batch_sampler.py | 4 +- fastNLP/core/samplers/reproducible_sampler.py | 4 +- fastNLP/core/samplers/unrepeated_sampler.py | 4 +- fastNLP/core/utils/rich_progress.py | 79 +++++++++- 
fastNLP/core/utils/torch_utils.py | 6 +- fastNLP/core/utils/utils.py | 5 +- tests/core/log/test_logger.py | 48 +++--- tests/helpers/utils.py | 15 ++ 20 files changed, 339 insertions(+), 121 deletions(-) diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py index 99e47dfe..96e4372b 100644 --- a/fastNLP/core/callbacks/callback.py +++ b/fastNLP/core/callbacks/callback.py @@ -32,100 +32,205 @@ class Callback: def on_sanity_check_end(self, trainer, sanity_check_res): r""" 在 '预跑'检测 开始后会被触发; + + :param trainer: + :param sanity_check_res: 预跑的 evaluate 结果 + :return: """ pass def on_train_begin(self, trainer): r""" 在训练开始前会被触发; + + :param trainer: + :return: """ pass def on_train_end(self, trainer): r""" 在训练完成后会被触发; + + :param trainer: + :return: """ pass def on_train_epoch_begin(self, trainer): r""" 在训练过程中的每一个 epoch 开始前会被触发; + + :param trainer: + :return: """ pass def on_train_epoch_end(self, trainer): r""" - 在训练过程中的每一个 epoch 完成后会被触发; + 在训练过程中的每一个 epoch 完成后会被触发;此时 trainer.cur_epoch_idx 已经完成加 1 操作。 + + :param trainer: + :return: """ pass def on_fetch_data_begin(self, trainer): r""" - 在训练过程中拿到当前的具体的一个 batch 前会被触发; + 在训练过程中准备取出下一个 batch 的数据时触发 + + :param trainer: + :return: """ pass def on_fetch_data_end(self, trainer): r""" - 在训练过程中拿到当前的具体的一个 batch 后会被触发; + 在训练过程中拿到当前的 batch 数据后会被触发; + + :param trainer: + :return: """ pass def on_train_batch_begin(self, trainer, batch, indices): r""" - 在训练过程中开始具体的一个 batch 前会被触发; + 在取得数据,执行完 input_mapping (如果 Trainer 传有该参数),并且移动 batch 中的 tensor 到了指定设备。 + 其中 batch 中的数据格式要么是 Dataloader 返回的每个 batch 的格式;要么是 input_mapping 之后的内容。 + 如果 batch 是 dict 类型,直接增删其中的 key 或 修改其中的 value 会影响到输入到 model 的中的 batch 数据。 :param trainer: `fastNLP.Trainer` - :param batch: 当前正在运行的一个 batch; - :param indices: 当前的 batch 在一个 epoch 中的位置,用于用户方便地通过该 callback 函数定位具体的数据; + :param batch: batch 的数据,已经经过 input_mapping (如果有) 以及 移动到指定设备 。 + :param list[int] indices: 当前的 batch 是 dataset 中的哪些数据 """ pass def on_train_batch_end(self, trainer): + """ + 完成一个 batch 的训练(forward)、梯度回传(backward)、梯度更新(step)、梯度置零、batch_idx_in_epoch与 + global_forward_batches累计加1操作。其中梯度更新】梯度置零操作会考虑 accumulation_steps ,所以不一定在当前 batch 会 + 执行。 + + :param trainer: + :return: + """ pass def on_exception(self, trainer, exception): + """ + 在训练过程遇到异常时调用。 + + :param trainer: + :param exception: 遭遇的异常。 + :return: + """ pass def on_save_model(self, trainer): + """ + 当将要保存模型时调用,此刻模型还未保存。 + + :param trainer: + :return: + """ pass def on_load_model(self, trainer): + """ + 当将要加载模型时调用,此刻模型还未加载。 + + :param trainer: + :return: + """ pass def on_save_checkpoint(self, trainer) -> Dict: """ - 当确定前后两个 callback 是一样的(callback_name 相同,意味着它们所起的职能相同)时,它们在该函数中则应当保存使该 callback 正常 - 工作的状态;而不应该让该函数去判断两个 callback 是否一样; + 当 Trainer 将要保存 checkpoint 的时候触发,该函数用于保存当前 callback 在恢复需要的相关数据。 + + :param trainer: + :return: """ pass def on_load_checkpoint(self, trainer, states: Optional[Dict]): r""" - 如果一个 callback 在断点重训前没有保存状态,或者其 `callback_name` 与其余的 callback 重名时,`states` 为 None; + 当 Trainer 要恢复 checkpoint 的时候触发( Trainer 与 Driver 已经加载好自身的状态),参数 states 为 on_save_checkpoint() + 的返回值。 + + :param trainer: + :param states: + :return: """ pass def on_before_backward(self, trainer, outputs): + """ + 在 backward 前执行。 + + :param trainer: + :param outputs: model 的返回内容。如果有 output_mapping ,则 outputs 中的内容为已经执行了 output_mapping 后的结果。 + :return: + """ pass def on_after_backward(self, trainer): + """ + 在 backward 后执行。在多卡场景下,由于 accumulation_steps 的影响,仅在需要真正 update 参数那次梯度回传才会触发梯度同步, + 因此在多卡且使用 accumulation_steps 时,可能存在某些 step 各卡上梯度不一致的问题。 + + :param trainer: + :return: 
+ """ pass def on_before_optimizer_step(self, trainer, optimizers): + """ + 在进行 optimizer 优化进行前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 + + :param trainer: + :param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 + :return: + """ pass def on_before_zero_grad(self, trainer, optimizers): + """ + 在进行模型梯度置零前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 + + :param trainer: + :param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 + :return: + """ pass def on_validate_begin(self, trainer): + """ + 在将要进行 validate 时调用。如果是设置的以 step 数量 或 自定义地 决定 validate 的频率,该接口是在 on_train_batch_end 之后 + 进行调用。如果是以 epoch 数量决定调用,该接口是在 on_train_epoch_end 之后调用。 + + :param trainer: + :return: + """ pass def on_validate_end(self, trainer, results): + """ + 结束 validate 时调用,并把 validate 的结果传入。 + + :param trainer: + :param results: + :return: + """ pass @property def callback_name(self): + """ + callback 的名称,我们会使用该名称从 checkpoint 中读取的相应的 state 并传递给 on_load_checkpoint() 函数。 + + :return: + """ return self.__class__.__name__ @@ -226,10 +331,21 @@ class HasMonitorCallback(Callback): :param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。 :return: """ + better = self.is_former_monitor_value_better(monitor_value, self.monitor_value) + if keep_if_better and better: + self.monitor_value = monitor_value + return better + + def is_former_monitor_value_better(self, monitor_value1, monitor_value2): + """ + 传入的两个值中,是否monitor_value1的结果更好。 + + :param monitor_value1: + :param monitor_value2: + :return: + """ better = False - if (self.larger_better and monitor_value > self.monitor_value) or \ - (not self.larger_better and monitor_value < self.monitor_value): + if (self.larger_better and monitor_value1 > monitor_value2) or \ + (not self.larger_better and monitor_value1 < monitor_value2): better = True - if keep_if_better: - self.monitor_value = monitor_value return better \ No newline at end of file diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py index 839a9522..82bfe404 100644 --- a/fastNLP/core/callbacks/checkpoint_callback.py +++ b/fastNLP/core/callbacks/checkpoint_callback.py @@ -15,7 +15,6 @@ from fastNLP.core.callbacks.utils import _get_monitor_value from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_LAUNCH_TIME from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir -from fastNLP.core.utils import apply_to_collection class CheckpointCallback(HasMonitorCallback): @@ -178,8 +177,7 @@ class CheckpointCallback(HasMonitorCallback): else: _least_valuable_model = (min if self.larger_better else max)(self._topk_model, key=lambda x: self._topk_model[x]) - if (self.larger_better and monitor_value > self._topk_model[_least_valuable_model]) or \ - (self.larger_better is False and monitor_value < self._topk_model[_least_valuable_model]): + if self.is_former_monitor_value_better(monitor_value, self._topk_model[_least_valuable_model]): self._topk_model[folder_name] = monitor_value _should_save = True self._topk_model.pop(_least_valuable_model) @@ -208,21 +206,6 @@ class CheckpointCallback(HasMonitorCallback): **self.kwargs ) - def _get_validate_metric(self, res: Dict): - """ - 该函数用于从 `Evaluator` 的结果中找到属于当前 CheckpointCallback 的 metric result(根据 monitor); - 如果用户输入在 res 中没有找到,我们会查询所有的 validate 结果字典的键值,根据 最长公共字符串 匹配,使用最长匹配的结果值; - :param res: - :return: - """ - use_monitor, value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=res) - if self._real_monitor != use_monitor: - logger.warning(f"We can not find `{self._real_monitor}` in the 
evaluation result (with keys as {list(res.keys())}), " - f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.") - self._real_monitor = use_monitor - - return value - @property def folder_prefix(self): raise NotImplementedError("The `folder_prefix` is not specified") diff --git a/fastNLP/core/collators/collator.py b/fastNLP/core/collators/collator.py index 78b07751..f468dd4c 100644 --- a/fastNLP/core/collators/collator.py +++ b/fastNLP/core/collators/collator.py @@ -197,7 +197,7 @@ class _MultiCollator: collator.set_input(*field_names) flag = False if flag: - warnings.warn("AutoCollator is remove, set_input is unavailable!!") + warnings.warn("AutoCollator is removed, set_input is unavailable!!") return self diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index b193f877..479686e1 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -223,7 +223,6 @@ class Evaluator: def remove_progress_bar(self, dataloader_name): if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): f_rich_progress.destroy_task(self._rich_task_id) - f_rich_progress.refresh() # 使得最终的bar可以消失 delattr(self, '_rich_task_id') elif self.progress_bar == 'raw': desc = 'Evaluation ends' @@ -234,7 +233,6 @@ class Evaluator: def finally_progress_bar(self): if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): f_rich_progress.destroy_task(self._rich_task_id) - f_rich_progress.refresh() delattr(self, '_rich_task_id') @property @@ -359,20 +357,23 @@ class _MetricsWrapper: if is_dataclass(outputs): outputs = dataclass_to_dict(outputs) for metric in self._metrics: + args = [] if not isinstance(batch, dict): - raise RuntimeError(f"When the output of the DataLoader is of type:`{type(batch)}`, please either directly" - f" return a dict from your DataLoader or use `input_mapping` to convert it into dict type.") + logger.warning_once(f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on " + f"the output of model to update metric.") + else: + args.append(batch) if not isinstance(outputs, dict): - raise RuntimeError(f"When the output of your model is of type:`{type(batch)}`, please either directly" + raise RuntimeError(f"The output of your model is of type:`{type(batch)}`, please either directly" f" return a dict from your model or use `output_mapping` to convert it into dict type.") if isinstance(metric, Metric): - auto_param_call(metric.update, batch, outputs) + auto_param_call(metric.update, batch, *args) elif _is_torchmetrics_metric(metric): - auto_param_call(metric.update, batch, outputs) + auto_param_call(metric.update, batch, *args) elif _is_allennlp_metric(metric): - auto_param_call(metric.__call__, batch, outputs) + auto_param_call(metric.__call__, batch, *args) elif _is_paddle_metric(metric): - res = auto_param_call(metric.compute, batch, outputs) + res = auto_param_call(metric.compute, batch, *args) metric.update(res) def reset(self): diff --git a/fastNLP/core/controllers/loops/train_batch_loop.py b/fastNLP/core/controllers/loops/train_batch_loop.py index 5d127359..a3219e6d 100644 --- a/fastNLP/core/controllers/loops/train_batch_loop.py +++ b/fastNLP/core/controllers/loops/train_batch_loop.py @@ -7,6 +7,7 @@ from typing import Optional, Callable from .loop import Loop from fastNLP.core.log import logger from fastNLP.core.utils import match_and_substitute_params +from fastNLP.core.utils.exceptions import EarlyStopException class TrainBatchLoop(Loop): @@ -23,13 +24,15 @@ 
class TrainBatchLoop(Loop): try: trainer.on_fetch_data_begin() batch = next(dataloader) - batch = match_and_substitute_params(trainer.input_mapping, batch) indices = get_batch_indices() - batch = trainer.move_data_to_device(batch) trainer.on_fetch_data_end() + batch = match_and_substitute_params(trainer.input_mapping, batch) + batch = trainer.move_data_to_device(batch) except StopIteration: break - except BaseException as e: # TODO 把这里的信息写入进去 + except EarlyStopException: # 在 Trainer 处理 earlystop 的 exception + break + except BaseException as e: if indices: logger.debug(f"The following exception happens when running on samples: {indices}") raise e diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 6d154770..5daee856 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -677,7 +677,7 @@ class Trainer(TrainerEventTrigger): self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') # 5. 恢复所有 callback 的状态; - self.on_load_checkpoint(states["callback_states"]) + self.train_stepeckpoint(states["callback_states"]) self.driver.barrier() diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index d56dbac9..13eae93c 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -54,7 +54,7 @@ class TorchDataLoader(DataLoader): pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[Callable] = None, multiprocessing_context=None, generator=None, prefetch_factor: int = 2, - persistent_workers: bool = False, as_numpy: bool = False) -> None: + persistent_workers: bool = False, as_numpy: bool = False, **kwargs) -> None: """ :param dataset: 实现了__getitem__和__len__的数据容器 diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index 9630a3a0..5b8ec635 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -788,13 +788,14 @@ class DataSet: def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: """ - 设置每个field_name的padding值,默认为0,只有当Auto_collate存在时该方法有效 + 设置每个field_name的padding值,默认为0,只有当AutoCollator存在时该方法有效 当val=None时,意味着给定的field_names都不需要尝试padding :param field_names: dataset存在的field_name - :param val: 默认为0 + :param val: 默认为0。如果为 None ,则为不对 field 进行 padding 。 :return: """ + # TODO 需要去重复 for field_name in field_names: self.collate_fns.set_pad_val(field_name, val=val) @@ -805,6 +806,7 @@ class DataSet: :param field_names: :return: """ + # self.collate_fns.set_input(*field_names) def get_collator(self) -> _MultiCollator: diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index 4cf207cd..3537d0b3 100644 --- a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -12,6 +12,7 @@ if _NEED_IMPORT_TORCH: import torch import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel + from torch.utils.data import BatchSampler __all__ = [ 'TorchDDPDriver' @@ -524,7 +525,8 @@ class TorchDDPDriver(TorchDriver): num_replicas=self.world_size, rank=self.global_rank ) - return replace_sampler(dataloader, sampler) + batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False) + return replace_batch_sampler(dataloader, batch_sampler) else: raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") diff --git a/fastNLP/core/drivers/torch_driver/dist_utils.py 
b/fastNLP/core/drivers/torch_driver/dist_utils.py index 5e3819e7..ad9e6794 100644 --- a/fastNLP/core/drivers/torch_driver/dist_utils.py +++ b/fastNLP/core/drivers/torch_driver/dist_utils.py @@ -3,28 +3,20 @@ import pickle _pickler = pickle.Pickler _unpickler = pickle.Unpickler from typing import Any, List -from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 - +from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 +from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: import torch from torch import distributed as dist - try: - from torch._C._distributed_c10d import ProcessGroupMPI - except ImportError: - _MPI_AVAILABLE = False - - try: - from torch._C._distributed_c10d import ProcessGroupNCCL - except ImportError: - _NCCL_AVAILABLE = False - - try: - from torch._C._distributed_c10d import ProcessGroupGloo - from torch._C._distributed_c10d import _ProcessGroupWrapper - except ImportError: - _GLOO_AVAILABLE = False + if _TORCH_GREATER_EQUAL_1_8: + try: + from torch._C._distributed_c10d import ProcessGroupGloo + from torch._C._distributed_c10d import _ProcessGroupWrapper + except ImportError: + pass + from fastNLP.core.utils import apply_to_collection @@ -42,7 +34,7 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): ) -def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None): +def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP): """ 从其它 rank gather 东西到 dst rank 。 @@ -91,6 +83,9 @@ def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None) >>> output ['foo', 12, {1: 2}] """ + if group is None: + group = DEFAULT_TORCH_GROUP + if dist.distributed_c10d._rank_not_in_group(group): return @@ -193,7 +188,7 @@ def _to_device(tensor, device): return tensor.contiguous().to(device) -def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: +def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) ->List: """ 实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 @@ -217,7 +212,8 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: :param group: :return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 """ - # # 首先将所有的都移动到cpu上并且连续,防止有 pickle 出问题 + if group is None: + group = DEFAULT_TORCH_GROUP if isinstance(obj, torch.Tensor): objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))] dist.all_gather(objs, obj, group=group) @@ -232,7 +228,7 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: return objs -def fastnlp_torch_broadcast_object(obj, src, device=None, group=None): +def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GROUP): """ 将 src 上的 obj 对象广播到其它 rank 上。 @@ -242,6 +238,8 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=None): :param group: :return: """ + if group is None: + group = DEFAULT_TORCH_GROUP cur_rank = dist.get_rank(group) if cur_rank == src: # 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里 @@ -339,15 +337,18 @@ def all_gather_object(object_list, obj, group=None): return input_tensor, local_size = _object_to_tensor(obj) - current_device = torch.device("cpu") - is_nccl_backend = _check_for_nccl_backend(group) - if is_nccl_backend: - # See note about using torch.cuda.current_device() here in docstring. - # We cannot simply use my_rank since rank == device is not necessarily - # true. 
- current_device = torch.device("cuda", torch.cuda.current_device()) - input_tensor = input_tensor.to(current_device) - local_size = local_size.to(current_device) + if _TORCH_GREATER_EQUAL_1_8: + current_device = torch.device("cpu") + is_nccl_backend = _check_for_nccl_backend(group) + if is_nccl_backend: + # See note about using torch.cuda.current_device() here in docstring. + # We cannot simply use my_rank since rank == device is not necessarily + # true. + current_device = torch.device("cuda", torch.cuda.current_device()) + input_tensor = input_tensor.to(current_device) + local_size = local_size.to(current_device) + else: + current_device = torch.cuda.current_device() # Gather all local sizes. This is so that we can find the max size, and index # until the correct size when deserializing the tensors. group_size = dist.get_world_size(group=group) diff --git a/fastNLP/core/drivers/torch_driver/utils.py b/fastNLP/core/drivers/torch_driver/utils.py index 4210dac5..cdc6cea9 100644 --- a/fastNLP/core/drivers/torch_driver/utils.py +++ b/fastNLP/core/drivers/torch_driver/utils.py @@ -8,6 +8,7 @@ import numpy as np import inspect from fastNLP.envs.imports import _NEED_IMPORT_TORCH +from fastNLP.core.samplers import re_instantiate_sampler if _NEED_IMPORT_TORCH: import torch @@ -295,7 +296,6 @@ def replace_sampler(dataloader: "DataLoader", sampler): "manually add the `DistributedSampler` as: " f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`." ) - return type(dataloader)(**reconstruct_args) @@ -307,12 +307,8 @@ def _dataloader_init_kwargs_resolve_sampler( """ batch_sampler = getattr(dataloader, "batch_sampler") # checking the batch sampler type is different than PyTorch default. - if batch_sampler is not None and type(batch_sampler) is not BatchSampler: - batch_sampler = type(batch_sampler)( - sampler, - batch_size=batch_sampler.batch_size, - drop_last=batch_sampler.drop_last, - ) + if batch_sampler is not None and not isinstance(batch_sampler, BatchSampler): + batch_sampler = re_instantiate_sampler(batch_sampler) return { "sampler": None, @@ -343,6 +339,9 @@ def replace_batch_sampler(dataloader, new_batch_sampler): params = {k: getattr(dataloader, k) for k in params_keys} params["batch_sampler"] = new_batch_sampler return type(dataloader)(**params) + # TODO 这里是否可以auto_param_call一下 + # return auto_param_call(type(dataloader), params, {'self': type(dataloader).__new__()}, + # signature_fn=type(dataloader).__init__) def optimizer_state_to_device(state, device): diff --git a/fastNLP/core/log/logger.py b/fastNLP/core/log/logger.py index ae89ad3f..9763ab4a 100644 --- a/fastNLP/core/log/logger.py +++ b/fastNLP/core/log/logger.py @@ -51,6 +51,7 @@ class LoggerSingleton(type): class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): def __init__(self, name): super().__init__(name) + self._warning_msgs = set() def add_file(self, path: Optional[Union[str, Path]] = None, level='AUTO', remove_other_handlers: bool = False, mode: str = "w"): @@ -108,6 +109,21 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): kwargs = self._add_rank_info(kwargs) self._log(WARNING, msg, args, **kwargs) + def warning_once(self, msg, *args, **kwargs): + """ + 通过 warning 内容只会 warning 一次 + + :param msg: + :param args: + :param kwargs: + :return: + """ + if msg not in self._warning_msgs: + if self.isEnabledFor(WARNING): + kwargs = self._add_rank_info(kwargs) + self._log(WARNING, msg, args, **kwargs) + self._warning_msgs.add(msg) + def warn(self, msg, *args, **kwargs): warnings.warn("The 
'warn' method is deprecated, " "use 'warning' instead", DeprecationWarning, 2) diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index c4116e24..d1041f08 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -166,8 +166,8 @@ class BucketedBatchSampler(ReproducibleBatchSampler): :param kwargs: fastNLP 保留使用 """ super().__init__() - if isinstance(dataset, DataSet): - length = dataset.get_field(length) + if isinstance(dataset, DataSet) and isinstance(length, str): + length = dataset.get_field(length).content if not isinstance(length[0], int): length = list(map(len, length)) else: diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index 1dc226a5..f48e2fc6 100644 --- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -295,8 +295,8 @@ class SortedSampler(SequentialSampler): :param kwargs: fastNLP 保留使用 """ super().__init__(dataset=dataset, **kwargs) - if isinstance(dataset, DataSet): - length = dataset.get_field(length) + if isinstance(dataset, DataSet) and isinstance(length, str): + length = dataset.get_field(length).content if not isinstance(length[0], int): length = list(map(len, length)) else: diff --git a/fastNLP/core/samplers/unrepeated_sampler.py b/fastNLP/core/samplers/unrepeated_sampler.py index d7913d20..02ec1162 100644 --- a/fastNLP/core/samplers/unrepeated_sampler.py +++ b/fastNLP/core/samplers/unrepeated_sampler.py @@ -105,8 +105,8 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler): :param kwargs: fastNLP 保留使用 """ super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) - if isinstance(dataset, DataSet): - length = dataset.get_field(length) + if isinstance(dataset, DataSet) and isinstance(length, str): + length = dataset.get_field(length).content if not isinstance(length[0], int): length = list(map(len, length)) else: diff --git a/fastNLP/core/utils/rich_progress.py b/fastNLP/core/utils/rich_progress.py index a865f4c1..82747a01 100644 --- a/fastNLP/core/utils/rich_progress.py +++ b/fastNLP/core/utils/rich_progress.py @@ -6,7 +6,7 @@ import sys from typing import Any, Union, Optional -from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live +from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live, Text, ProgressSample from rich.progress import ProgressColumn, TimeRemainingColumn, BarColumn, TimeElapsedColumn, TextColumn __all__ = [ @@ -146,24 +146,99 @@ class FRichProgress(Progress, metaclass=Singleton): if task_id in self._tasks: super().stop_task(task_id) super().remove_task(task_id) + self.refresh() # 使得bar不残留 def start(self) -> None: super().start() self.console.show_cursor(show=True) + def update( + self, + task_id: TaskID, + *, + total: Optional[float] = None, + completed: Optional[float] = None, + advance: Optional[float] = None, + description: Optional[str] = None, + visible: Optional[bool] = None, + refresh: bool = False, + **fields: Any, + ) -> None: + """Update information associated with a task. + + Args: + task_id (TaskID): Task id (returned by add_task). + total (float, optional): Updates task.total if not None. + completed (float, optional): Updates task.completed if not None. + advance (float, optional): Add a value to task.completed if not None. + description (str, optional): Change task description if not None. 
+ visible (bool, optional): Set visible flag if not None. + refresh (bool): Force a refresh of progress information. Default is False. + **fields (Any): Additional data fields required for rendering. + """ + with self._lock: + task = self._tasks[task_id] + completed_start = task.completed + + if total is not None and total != task.total: + task.total = total + task._reset() + if advance is not None: + task.completed += advance + if completed is not None: + task.completed = completed + if description is not None: + task.description = description + if visible is not None: + task.visible = visible + task.fields.update(fields) + update_completed = task.completed - completed_start + + current_time = self.get_time() + old_sample_time = current_time - self.speed_estimate_period + _progress = task._progress + + popleft = _progress.popleft + # 这里修改为至少保留一个,防止超长时间的迭代影响判断 + while len(_progress)>1 and _progress[0].timestamp < old_sample_time: + popleft() + if update_completed > 0: + _progress.append(ProgressSample(current_time, update_completed)) + if task.completed >= task.total and task.finished_time is None: + task.finished_time = task.elapsed + + if refresh: + self.refresh() + + +class SpeedColumn(ProgressColumn): + """ + 显示 task 的速度。 + + """ + def render(self, task: "Task"): + speed = task.speed + if speed is None: + return Text('-- it./s', style='progress.data.speed') + if speed > 0.1: + return Text(str(round(speed, 2))+' it./s', style='progress.data.speed') + else: + return Text(str(round(1/speed, 2))+' s/it.', style='progress.data.speed') + if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: f_rich_progress = FRichProgress().new_progess( "[progress.description]{task.description}", "[progress.percentage]{task.percentage:>3.0f}%", BarColumn(), + SpeedColumn(), TimeElapsedColumn(), "/", TimeRemainingColumn(), TextColumn("{task.fields[post_desc]}", justify="right"), transient=True, disable=False, - speed_estimate_period=1 + speed_estimate_period=30 ) else: f_rich_progress = DummyFRichProgress() diff --git a/fastNLP/core/utils/torch_utils.py b/fastNLP/core/utils/torch_utils.py index 9dea93dd..2dfc0802 100644 --- a/fastNLP/core/utils/torch_utils.py +++ b/fastNLP/core/utils/torch_utils.py @@ -1,9 +1,11 @@ from abc import ABC from typing import Any, Union, Optional -from fastNLP.envs.imports import _NEED_IMPORT_TORCH - +from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_8 +DEFAULT_TORCH_GROUP = None if _NEED_IMPORT_TORCH: import torch + if not _TORCH_GREATER_EQUAL_1_8: + DEFAULT_TORCH_GROUP = torch.distributed.distributed_c10d.group.WORLD __all__ = [ 'torch_move_data_to_device' diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index 46211581..c402fe11 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -81,7 +81,10 @@ def check_fn_not_empty_params(fn: Optional[Callable] = None, param_num: Optional def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None, mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any: r""" - 1.该函数用来提供给用户根据字符串匹配从而实现自动计算; + 该函数会根据输入函数的形参名从*args(因此都需要是dict类型)中找到匹配的值进行调用,如果传入的数据与fn的形参不匹配,可以通过mapping + 参数进行转换。mapping参数中的一对(key,value)表示以这个key在*args中找到值,并将这个值传递给形参名为value的参数。 + + 1.该函数用来提供给用户根据字符串匹配从而实现自动调用; 2.注意 mapping 默认为 None,如果你希望指定输入和运行函数的参数的对应方式,那么你应当让 mapping 为一个这样的字典传入进来; 如果 mapping 不为 None,那么我们一定会先使用 mapping 将输入的字典的 keys 修改过来,因此请务必亲自检查 mapping 的正确性; 3.如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值; diff --git a/tests/core/log/test_logger.py 
b/tests/core/log/test_logger.py index da9b7b6b..4fe49bef 100644 --- a/tests/core/log/test_logger.py +++ b/tests/core/log/test_logger.py @@ -6,13 +6,16 @@ import logging import re from fastNLP.envs.env import FASTNLP_LAUNCH_TIME -from tests.helpers.utils import magic_argv_env_context from fastNLP.core import synchronize_safe_rm +from fastNLP.core.log.logger import logger + +from tests.helpers.utils import magic_argv_env_context, recover_logger # 测试 TorchDDPDriver; @magic_argv_env_context -def test_add_file_ddp_1(): +@recover_logger +def test_add_file_ddp_1_torch(): """ 测试 path 是一个文件的地址,但是这个文件所在的文件夹存在; @@ -56,11 +59,11 @@ def test_add_file_ddp_1(): synchronize_safe_rm(filepath) dist.barrier() dist.destroy_process_group() - logger.removeHandler(handler) @magic_argv_env_context -def test_add_file_ddp_2(): +@recover_logger +def test_add_file_ddp_2_torch(): """ 测试 path 是一个文件的地址,但是这个文件所在的文件夹不存在; """ @@ -103,14 +106,14 @@ def test_add_file_ddp_2(): assert len(pattern.findall(line)) == 1 finally: synchronize_safe_rm(path) - logger.removeHandler(handler) dist.barrier() dist.destroy_process_group() @magic_argv_env_context -def test_add_file_ddp_3(): +@recover_logger +def test_add_file_ddp_3_torch(): """ path = None; @@ -155,10 +158,10 @@ def test_add_file_ddp_3(): synchronize_safe_rm(file) dist.barrier() dist.destroy_process_group() - logger.removeHandler(handler) @magic_argv_env_context -def test_add_file_ddp_4(): +@recover_logger +def test_add_file_ddp_4_torch(): """ 测试 path 是文件夹; """ @@ -200,7 +203,6 @@ def test_add_file_ddp_4(): assert len(pattern.findall(line)) == 1 finally: synchronize_safe_rm(path) - logger.removeHandler(handler) dist.barrier() dist.destroy_process_group() @@ -209,12 +211,11 @@ def test_add_file_ddp_4(): class TestLogger: msg = 'some test log msg' + @recover_logger def test_add_file_1(self): """ 测试 path 是一个文件的地址,但是这个文件所在的文件夹存在; """ - from fastNLP.core.log.logger import logger - path = Path(tempfile.mkdtemp()) try: filepath = path.joinpath('log.txt') @@ -225,14 +226,12 @@ class TestLogger: assert self.msg in line finally: synchronize_safe_rm(path) - logger.removeHandler(handler) + @recover_logger def test_add_file_2(self): """ 测试 path 是一个文件的地址,但是这个文件所在的文件夹不存在; """ - from fastNLP.core.log.logger import logger - origin_path = Path(tempfile.mkdtemp()) try: @@ -245,14 +244,12 @@ class TestLogger: assert self.msg in line finally: synchronize_safe_rm(origin_path) - logger.removeHandler(handler) + @recover_logger def test_add_file_3(self): """ 测试 path 是 None; """ - from fastNLP.core.log.logger import logger - handler = logger.add_file() logger.info(self.msg) @@ -264,14 +261,12 @@ class TestLogger: line = ''.join([l for l in f]) assert self.msg in line file.unlink() - logger.removeHandler(handler) + @recover_logger def test_add_file_4(self): """ 测试 path 是文件夹; """ - from fastNLP.core.log.logger import logger - path = Path(tempfile.mkdtemp()) try: handler = logger.add_file(path) @@ -285,16 +280,21 @@ class TestLogger: assert self.msg in line finally: synchronize_safe_rm(path) - logger.removeHandler(handler) + @recover_logger def test_stdout(self, capsys): - from fastNLP.core.log.logger import logger - handler = logger.set_stdout(stdout="raw") logger.info(self.msg) logger.debug('aabbc') captured = capsys.readouterr() assert "some test log msg\n" == captured.out - logger.removeHandler(handler) + @recover_logger + def test_warning_once(self, capsys): + logger.warning_once('#') + logger.warning_once('#') + logger.warning_once('@') + captured = capsys.readouterr() + assert 
captured.out.count('#') == 1 + assert captured.out.count('@') == 1 diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index f4effc1f..b876c289 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -13,6 +13,7 @@ import numpy as np from fastNLP.envs.env import FASTNLP_GLOBAL_RANK from fastNLP.core.drivers.utils import distributed_open_proc +from fastNLP.core.log import logger def get_class_that_defined_method(meth): @@ -32,6 +33,20 @@ def get_class_that_defined_method(meth): return getattr(meth, '__objclass__', None) # handle special descriptor objects +def recover_logger(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + # 保存logger的状态 + handlers = [handler for handler in logger.handlers] + level = logger.level + res = fn(*args, **kwargs) + logger.handlers = handlers + logger.setLevel(level) + return res + + return wrapper + + def magic_argv_env_context(fn): @wraps(fn) From 76a1e69022eda7356dd118635db4580915e4c286 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Wed, 13 Apr 2022 14:27:01 +0800 Subject: [PATCH 15/26] little change --- fastNLP/core/controllers/trainer.py | 4 ++-- tests/core/controllers/test_trainer_w_evaluator_torch.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 6d154770..24c9f2b7 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -105,8 +105,8 @@ class Trainer(TrainerEventTrigger): 如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value; 如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; - 如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 - 为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`; + 如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 + 为 True,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`; :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; :param fp16: 是否开启混合精度训练;默认为 False; :param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py index 8944e45d..699ee3b9 100644 --- a/tests/core/controllers/test_trainer_w_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py @@ -143,7 +143,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( accumulation_steps, n_epochs=6, ): - callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)] + callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.1, larger_better=True)] trainer = Trainer( model=model_and_optimizers.model, driver=driver, From 3ab93b2fae6ae60a417cc386a24230222d843b2d Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Wed, 13 Apr 2022 07:33:27 +0000 Subject: [PATCH 16/26] =?UTF-8?q?paddle=20driver=E5=8D=95=E5=8D=A1?= =?UTF-8?q?=E5=92=8Cutils=E7=9A=84pytest=E6=B5=8B=E8=AF=95=EF=BC=8C?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E6=96=AD=E7=82=B9=E9=87=8D=E8=AE=AD?= =?UTF-8?q?=E7=9A=84=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../drivers/paddle_driver/paddle_driver.py | 18 +- 
fastNLP/core/drivers/paddle_driver/utils.py | 16 +- .../paddle_driver/test_single_device.py | 328 ++++++++++++++++-- .../core/drivers/paddle_driver/test_utils.py | 56 ++- 4 files changed, 367 insertions(+), 51 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index a407a7b7..4362dcce 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -191,6 +191,8 @@ class PaddleDriver(Driver): :return: """ model = self.unwrap_model() + if isinstance(filepath, Path): + filepath = str(filepath) if only_state_dict: states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} paddle.save(states, filepath) @@ -211,6 +213,8 @@ class PaddleDriver(Driver): :return: """ model = self.unwrap_model() + if isinstance(filepath, Path): + filepath = str(filepath) # paddle 中,通过 paddle.jit.save 函数保存的模型也可以通过 paddle.load 加载为相应的 state dict # 但是此时对输入的 path 有要求,必须是 dir/filename 的形式,否则会报错。 dirname, filename = os.path.split(filepath) @@ -274,11 +278,11 @@ class PaddleDriver(Driver): logger.debug("Save optimizer state dict.") states["optimizers_state_dict"] = optimizers_state_dict - paddle.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) + paddle.save(states, str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: - states = paddle.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) + states = paddle.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) # 1. 加载 optimizers 的状态; optimizers_state_dict = states["optimizers_state_dict"] @@ -435,6 +439,16 @@ class PaddleDriver(Driver): res.shuffle = True else: res.shuffle = False + # RandomBatchSampler 的情况 + elif hasattr(dataloader.batch_sampler, "batch_sampler"): + batch_sampler = dataloader.batch_sampler.batch_sampler + res.sampler = batch_sampler.sampler + if hasattr(batch_sampler.sampler, "shuffle"): + res.shuffle = dataloader.batch_sampler.sampler.shuffle + elif isinstance(batch_sampler.sampler, RandomSampler): + res.shuffle = True + else: + res.shuffle = False else: res.sampler = None res.shuffle = False diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py index 36982b4c..895ec703 100644 --- a/fastNLP/core/drivers/paddle_driver/utils.py +++ b/fastNLP/core/drivers/paddle_driver/utils.py @@ -4,12 +4,14 @@ import struct import random import inspect import numpy as np +from copy import deepcopy from contextlib import ExitStack, closing from enum import IntEnum from typing import Dict, Optional, Union from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to +from fastNLP.core.samplers import RandomSampler from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES from fastNLP.core.log import logger @@ -18,7 +20,7 @@ if _NEED_IMPORT_PADDLE: import paddle from paddle import nn from paddle.nn import Layer - from paddle.io import DataLoader, BatchSampler + from paddle.io import DataLoader, BatchSampler, Dataset from paddle.amp import auto_cast, GradScaler else: from fastNLP.core.utils.dummy_class import DummyClass as Layer @@ -206,7 +208,6 @@ class DummyGradScaler: def state_dict(self): return {} - def _build_fp16_env(dummy=False): if dummy: auto_cast = ExitStack @@ -260,7 +261,7 @@ def 
get_device_from_visible(device: Union[str, int], output_type=int): """ 在有 CUDA_VISIBLE_DEVICES 的情况下,获取对应的设备。 如 CUDA_VISIBLE_DEVICES=2,3 ,device=3 ,则返回1。 - :param devices: 未转化的设备名 + :param device: 未转化的设备名 :param output_type: 返回值的类型 :return: 转化后的设备id """ @@ -365,13 +366,8 @@ def replace_sampler(dataloader, new_sampler): """ 使用 `new_sampler` 重新构建一个 BatchSampler,并替换到 `dataloader` 中 """ - new_batch_sampler = BatchSampler( - dataset=dataloader.batch_sampler.dataset, - sampler=new_sampler, - shuffle=isinstance(dataloader.batch_sampler.sampler, paddle.io.RandomSampler), - batch_size=dataloader.batch_sampler.batch_size, - drop_last=dataloader.batch_sampler.drop_last - ) + new_batch_sampler = deepcopy(dataloader.batch_sampler) + new_batch_sampler.sampler = new_sampler return replace_batch_sampler(dataloader, new_batch_sampler) def optimizer_state_to_device(state, device): diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 3d07766a..b9681121 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -1,11 +1,10 @@ import os -from numpy import isin os.environ["FASTNLP_BACKEND"] = "paddle" import pytest +from pathlib import Path from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver -from fastNLP.core.samplers.reproducible_sampler import RandomSampler -from fastNLP.core.samplers import RandomBatchSampler +from fastNLP.core.samplers import RandomBatchSampler, RandomSampler from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset from tests.helpers.datasets.torch_data import TorchNormalDataset @@ -42,27 +41,101 @@ def prepare_test_save_load(): driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) return driver1, driver2, dataloader -@pytest.mark.parametrize("reproducible", [True, False]) -@pytest.mark.parametrize("only_state_dict", [True, False]) -def test_save_and_load(prepare_test_save_load, reproducible, only_state_dict): +@pytest.mark.parametrize("only_state_dict", ([True, False])) +def test_save_and_load_with_randombatchsampler(only_state_dict): """ - 测试save和load函数 - TODO optimizer的state_dict为空,暂时不测试 + 测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 """ try: path = "model.ckp" - driver1, driver2, dataloader = prepare_test_save_load - dataloader = driver1.set_dist_repro_dataloader(dataloader, "dist", reproducible) - driver1.save(path, {}, dataloader, only_state_dict, should_save_model=True) - driver2.load(path, dataloader, only_state_dict, should_load_model=True) + driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) + dataset = PaddleRandomMaxDataset(80, 10) + dataloader = DataLoader( + dataset=dataset, + batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) + ) + + # TODO 断点重训完善后在这里迭代几次 + + sampler_states = dataloader.batch_sampler.state_dict() + if only_state_dict: + driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True) + else: + driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) + states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + + # 1. 检查 optimizer 的状态 + # TODO optimizer 的 state_dict 总是为空 + + # 2. 
检查 batch_sampler 是否被正确地加载和替换 + replaced_loader = states["dataloader"] + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] + assert replaced_loader.batch_sampler.data_idx == sampler_states["data_idx"] + + # 3. 检查 model 的参数是否被正确加载 + for batch in dataloader: + res1 = driver1.validate_step(batch) + res2 = driver2.validate_step(batch) + + assert paddle.equal_all(res1["pred"], res2["pred"]) + + # 4. 检查 batch_idx + # TODO + finally: + synchronize_safe_rm(path) + +@pytest.mark.parametrize("only_state_dict", ([True, False])) +def test_save_and_load_with_randomsampler(only_state_dict): + """ + 测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 + """ + + try: + path = "model.ckp" + + driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) + dataset = PaddleRandomMaxDataset(80, 10) + batch_sampler = BatchSampler(dataset=dataset, batch_size=2) + batch_sampler.sampler = RandomSampler(dataset, True) + dataloader = DataLoader( + dataset, + batch_sampler=batch_sampler + ) + + # TODO 断点重训完善后在这里迭代几次 + + sampler_states = dataloader.batch_sampler.sampler.state_dict() + if only_state_dict: + driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True) + else: + driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) + states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + + # 1. 检查 optimizer 的状态 + # TODO optimizer 的 state_dict 总是为空 + + # 2. 检查 sampler 是否被正确地加载和替换 + replaced_loader = states["dataloader"] + + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] + assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] + assert replaced_loader.batch_sampler.sampler.num_consumed_samples == sampler_states["num_consumed_samples"] + assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] + assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] + # 3. 检查 model 的参数是否被正确加载 for batch in dataloader: res1 = driver1.validate_step(batch) res2 = driver2.validate_step(batch) assert paddle.equal_all(res1["pred"], res2["pred"]) + + # 4. 
检查 batch_idx + # TODO finally: synchronize_safe_rm(path) @@ -144,24 +217,138 @@ class TestSingleDeviceFunction: """ self.driver.move_data_to_device(paddle.rand((32, 64))) - # @pytest.mark.parametrize( - # "dist_sampler", [ - # "dist", - # RandomBatchSampler(BatchSampler(PaddleRandomMaxDataset(320, 10)), 32, False), - # RandomSampler(PaddleRandomMaxDataset(320, 10)) - # ] - # ) - # @pytest.mark.parametrize( - # "reproducible", - # [True, False] - # ) - # def test_set_dist_repro_dataloader(self, dist_sampler, reproducible): - # """ - # 测试set_dist_repro_dataloader函数 - # """ - # dataloader = DataLoader(PaddleRandomMaxDataset(320, 10), batch_size=100, shuffle=True) - - # res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible) + +class TestSetDistReproDataloder: + """ + 专门测试 set_dist_repro_dataloader 函数的类 + """ + def setup_method(self): + self.dataset = PaddleNormalDataset(20) + self.dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) + model = PaddleNormalModel_Classification_1(10, 32) + self.driver = PaddleSingleDriver(model, device="cpu") + + def test_set_dist_repro_dataloader_with_reproducible_false(self): + """ + 测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现 + 当dist为字符串时,此时应该返回原来的 dataloader + """ + replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=False) + + assert replaced_loader is self.dataloader + + def test_set_dist_repro_dataloader_with_reproducible_true(self): + """ + 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 + 当dist为字符串时,此时应该返回新的 dataloader,且 batch_sampler 为 RandomBatchSampler + """ + replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=True) + + assert not (replaced_loader is self.dataloader) + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) + assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size + assert replaced_loader.drop_last == self.dataloader.drop_last + + # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) + + def test_set_dist_repro_dataloader_with_dist_batch_sampler(self): + """ + 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler + 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler + """ + dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False) + replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False) + + assert not (replaced_loader is self.dataloader) + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert replaced_loader.batch_sampler is dist + + # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) + + def test_set_dist_repro_dataloader_with_dist_sampler(self): + """ + 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现 + 应该返回新的 dataloader,并将 batch_sampler.sampler 替换为 dist 对应的 Sampler + """ + dist = RandomSampler(self.dataset, shuffle=True) + replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False) + + assert not (replaced_loader is self.dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert not (replaced_loader.batch_sampler is self.dataloader.batch_sampler) + assert replaced_loader.batch_sampler.sampler is dist + assert 
replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size + + # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) + + def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self): + """ + 测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 + 应该返回新的 dataloader,且其余各项设置和原来相同 + """ + dataloader = DataLoader( + dataset=self.dataset, + batch_sampler=RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False) + ) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) + + assert not (replaced_loader is dataloader) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size + assert replaced_loader.drop_last == dataloader.drop_last + + # self.check_set_dist_repro_dataloader(dataloader, replaced_loader) + + def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self): + """ + 测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 + 应该返回新的 dataloader,且其余各项设置和原来相同 + """ + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) + batch_sampler.sampler = RandomSampler(self.dataset, True) + dataloader = DataLoader( + self.dataset, + batch_sampler=batch_sampler + ) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) + + assert not (replaced_loader is dataloader) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) + assert replaced_loader.batch_sampler.batch_size == 2 + + # self.check_set_dist_repro_dataloader(dataloader, replaced_loader) + + def check_set_dist_repro_dataloader(self, dataloader, replaced_loader): + """ + 测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确 + """ + # 迭代两个 batch + # 这里会发生 BatchSampler 里 yield 了多次但 dataloader 只取出一次的情况。 + already_seen_idx = set() + for idx, batch in replaced_loader: + already_seen_idx.update(batch) + if idx >= 1: + break + if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): + sampler_states = replaced_loader.batch_sampler.state_dict() + else: + sampler_states = replaced_loader.batch_sampler.sampler.state_dict() + print(sampler_states["data_idx"]) + + # 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range + left_idxes = set() + if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): + replaced_loader.batch_sampler.load_state_dict(sampler_states) + else: + replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states) + for idx, batch in enumerate(replaced_loader): + left_idxes.update(batch) + + assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) + assert len(left_idxes | already_seen_idx) == len(self.dataset) class TestPaddleDriverFunctions: """ @@ -229,7 +416,7 @@ class TestPaddleDriverFunctions: with pytest.raises(ValueError): PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) - def test_check_dataloader_legacy_in_test(self): + def test_check_dataloader_legality_in_test(self): """ 测试is_train参数为False时,_check_dataloader_legality函数的表现 """ @@ -372,11 +559,78 @@ class TestPaddleDriverFunctions: dataloader = DataLoader(PaddleNormalDataset()) self.driver.set_sampler_epoch(dataloader, 0) - def test_get_dataloader_args(self): + @pytest.mark.parametrize("batch_size", [16]) + @pytest.mark.parametrize("shuffle", [True, False]) + @pytest.mark.parametrize("drop_last", [True, False]) + 
def test_get_dataloader_args(self, batch_size, shuffle, drop_last): """ - 测试get_dataloader_args + 测试正常情况下 get_dataloader_args 的表现 """ - # 先确保不影响运行 - # TODO:正确性 - dataloader = DataLoader(PaddleNormalDataset()) - res = PaddleSingleDriver.get_dataloader_args(dataloader) \ No newline at end of file + dataloader = DataLoader( + PaddleNormalDataset(), + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + ) + res = PaddleSingleDriver.get_dataloader_args(dataloader) + + assert isinstance(res.dataset, PaddleNormalDataset) + assert isinstance(res.batch_sampler, BatchSampler) + if shuffle: + assert isinstance(res.sampler, paddle.io.RandomSampler) + else: + assert isinstance(res.sampler, paddle.io.SequenceSampler) + assert res.shuffle == shuffle + assert res.batch_size == batch_size + assert res.drop_last == drop_last + + @pytest.mark.parametrize("batch_size", [16]) + @pytest.mark.parametrize("shuffle", [True, False]) + @pytest.mark.parametrize("drop_last", [True, False]) + def test_get_dataloader_args_with_randombatchsampler(self, batch_size, shuffle, drop_last): + """ + 测试替换了 batch_sampler 后 get_dataloader_args 的表现 + """ + dataset = PaddleNormalDataset() + dataloader = DataLoader( + dataset, + batch_sampler=RandomBatchSampler( + BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle), + batch_size, + drop_last, + ) + ) + res = PaddleSingleDriver.get_dataloader_args(dataloader) + + assert isinstance(res.dataset, PaddleNormalDataset) + assert isinstance(res.batch_sampler, RandomBatchSampler) + if shuffle: + assert isinstance(res.sampler, paddle.io.RandomSampler) + else: + assert isinstance(res.sampler, paddle.io.SequenceSampler) + assert res.shuffle == shuffle + assert res.batch_size == batch_size + assert res.drop_last == drop_last + + @pytest.mark.parametrize("batch_size", [16]) + @pytest.mark.parametrize("shuffle", [True, False]) + @pytest.mark.parametrize("drop_last", [True, False]) + def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last): + """ + 测试替换了 sampler 后 get_dataloader_args 的表现 + """ + dataset = PaddleNormalDataset() + batch_sampler = BatchSampler(dataset, batch_size=batch_size, drop_last=drop_last) + batch_sampler.sampler = RandomSampler(dataset, shuffle) + dataloader = DataLoader( + dataset, + batch_sampler=batch_sampler, + ) + res = PaddleSingleDriver.get_dataloader_args(dataloader) + + assert isinstance(res.dataset, PaddleNormalDataset) + assert isinstance(res.batch_sampler, BatchSampler) + assert isinstance(res.sampler, RandomSampler) + assert res.shuffle == shuffle + assert res.batch_size == batch_size + assert res.drop_last == drop_last \ No newline at end of file diff --git a/tests/core/drivers/paddle_driver/test_utils.py b/tests/core/drivers/paddle_driver/test_utils.py index b072d09d..690d0fb8 100644 --- a/tests/core/drivers/paddle_driver/test_utils.py +++ b/tests/core/drivers/paddle_driver/test_utils.py @@ -1,4 +1,56 @@ -import unittest +import os +import pytest +os.environ["FASTNLP_BACKEND"] = "paddle" + +from fastNLP.core.drivers.paddle_driver.utils import ( + get_device_from_visible, + replace_batch_sampler, + replace_sampler, +) +from fastNLP.core.samplers import RandomBatchSampler, RandomSampler import paddle -from paddle.io import Dataset, DataLoader, DistributedBatchSampler \ No newline at end of file +from paddle.io import DataLoader, BatchSampler + +from tests.helpers.datasets.paddle_data import PaddleNormalDataset + +@pytest.mark.parametrize( + ("user_visible_devices, cuda_visible_devices, device, output_type, 
correct"), + ( + ("0,1,2,3,4,5,6,7", "0", "cpu", str, "cpu"), + ("0,1,2,3,4,5,6,7", "0", "cpu", int, "cpu"), + ("0,1,2,3,4,5,6,7", "3,4,5", "gpu:4", int, 1), + ("0,1,2,3,4,5,6,7", "3,4,5", "gpu:5", str, "gpu:2"), + ("3,4,5,6", "3,5", 0, int, 0), + ("3,6,7,8", "6,7,8", "gpu:2", str, "gpu:1"), + ) +) +def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices, device, output_type, correct): + os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices + os.environ["USER_CUDA_VISIBLE_DEVICES"] = user_visible_devices + res = get_device_from_visible(device, output_type) + assert res == correct + +def test_replace_batch_sampler(): + dataset = PaddleNormalDataset(10) + dataloader = DataLoader(dataset, batch_size=32) + batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) + + replaced_loader = replace_batch_sampler(dataloader, batch_sampler) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.dataset, PaddleNormalDataset) + assert len(replaced_loader.dataset) == len(dataset) + assert replaced_loader.batch_sampler.batch_size == 16 + +def test_replace_sampler(): + dataset = PaddleNormalDataset(10) + dataloader = DataLoader(dataset, batch_size=32) + sampler = RandomSampler(dataset) + + replaced_loader = replace_sampler(dataloader, sampler) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) \ No newline at end of file From 3ee6fc66f5b37d7cbd8ebbdfcc5ab02e002fab09 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Wed, 13 Apr 2022 15:37:08 +0800 Subject: [PATCH 17/26] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=20on=5Fafter?= =?UTF-8?q?=5Foptimizers=5Fstep=20=E5=92=8C=20on=5Fafter=5Fzero=5Fgrad=20?= =?UTF-8?q?=20=E7=9A=84callback=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/callback.py | 22 ++++++++++++++++++- fastNLP/core/callbacks/callback_events.py | 4 +++- fastNLP/core/callbacks/callback_manager.py | 10 ++++++++- fastNLP/core/controllers/trainer.py | 21 +++++++++++++----- fastNLP/core/controllers/utils/utils.py | 10 +++++++-- fastNLP/core/drivers/torch_driver/ddp.py | 8 ------- .../drivers/torch_driver/single_device.py | 8 ------- .../core/drivers/torch_driver/torch_driver.py | 8 +++++++ tests/helpers/callbacks/helper_callbacks.py | 10 +++++++-- 9 files changed, 72 insertions(+), 29 deletions(-) diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py index 96e4372b..0b9020fe 100644 --- a/fastNLP/core/callbacks/callback.py +++ b/fastNLP/core/callbacks/callback.py @@ -184,7 +184,7 @@ class Callback: """ pass - def on_before_optimizer_step(self, trainer, optimizers): + def on_before_optimizers_step(self, trainer, optimizers): """ 在进行 optimizer 优化进行前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 @@ -194,6 +194,16 @@ class Callback: """ pass + def on_after_optimizers_step(self, trainer, optimizers): + """ + 在进行 optimizer 优化进行后调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 + + :param trainer: + :param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 + :return: + """ + pass + def on_before_zero_grad(self, trainer, optimizers): """ 在进行模型梯度置零前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 @@ -204,6 +214,16 @@ class Callback: """ pass + def on_after_zero_grad(self, trainer, optimizers): + """ + 
在进行模型梯度置零后调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 + + :param trainer: + :param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 + :return: + """ + pass + def on_validate_begin(self, trainer): """ 在将要进行 validate 时调用。如果是设置的以 step 数量 或 自定义地 决定 validate 的频率,该接口是在 on_train_batch_end 之后 diff --git a/fastNLP/core/callbacks/callback_events.py b/fastNLP/core/callbacks/callback_events.py index 2bfe8e90..1c805ac2 100644 --- a/fastNLP/core/callbacks/callback_events.py +++ b/fastNLP/core/callbacks/callback_events.py @@ -92,8 +92,10 @@ class Events(EventEnum): ON_LOAD_CHECKPOINT = "on_load_checkpoint" ON_BEFORE_BACKWARD = "on_before_backward" ON_AFTER_BACKWARD = "on_after_backward" - ON_BEFORE_OPTIMIZER_STEP = "on_before_optimizer_step" + ON_BEFORE_OPTIMIZERS_STEP = "on_before_optimizers_step" + ON_AFTER_OPTIMIZERS_STEP = "on_after_optimizers_step" ON_BEFORE_ZERO_GRAD = "on_before_zero_grad" + ON_AFTER_ZERO_GRAD = "on_after_zero_grad" ON_VALIDATE_BEGIN = "on_validate_begin" ON_VALIDATE_END = "on_validate_end" diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index 8b53c70b..a962fe9f 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -278,13 +278,21 @@ class CallbackManager: pass @_transfer - def on_before_optimizer_step(self, trainer, optimizers): + def on_before_optimizers_step(self, trainer, optimizers): + pass + + @_transfer + def on_after_optimizers_step(self, trainer, optimizers): pass @_transfer def on_before_zero_grad(self, trainer, optimizers): pass + @_transfer + def on_after_zero_grad(self, trainer, optimizers): + pass + @_transfer def on_validate_begin(self, trainer): pass diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index fb62c3f1..a78af9d8 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -137,6 +137,7 @@ class Trainer(TrainerEventTrigger): else: self.driver_name = driver.__class__.__name__ self.device = device + self.optimizers = optimizers self.fp16 = fp16 self.input_mapping = input_mapping self.output_mapping = output_mapping @@ -440,9 +441,11 @@ class Trainer(TrainerEventTrigger): 2. 
函数作用 这一函数的作用在于检查用户定制的 batch_step_fn / TrainBatchLoop 是否能够正确地调用 callback 函数,更准确地说,当用户实际 - 定制了 ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") / + 定制了 ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad", + "on_after_zero_grad") / ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end", - "on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") + "on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad", + "on_after_zero_grad") 这些 callabck_fn 后,如果其同样也定制了 batch_step_fn / TrainBatchLoop,那么其有可能忘记了在自己的 batch_step_fn 中 上述的这些 callback 函数,而这个函数的作用就在于检测用户是否产生了这一行为; @@ -452,10 +455,12 @@ class Trainer(TrainerEventTrigger): 'batch_step_fn',为 False 时表示检测 'TrainBatchLoop'; """ if check_mode: - callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") + callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", + "on_before_zero_grad", "on_after_zero_grad") else: callbacks = ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end", - "on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") + "on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", + "on_before_zero_grad", "on_after_zero_grad") _not_called_callback_fns = [] for each_callback_fn in callbacks: if each_callback_fn in self.callback_manager.callback_fns: @@ -699,13 +704,15 @@ class Trainer(TrainerEventTrigger): def zero_grad(self): if (self.global_forward_batches + 1) % self.accumulation_steps == 0: - self.on_before_zero_grad(self.driver.optimizers) + self.on_before_zero_grad(self.optimizers) self.driver.zero_grad(self.set_grad_to_none) + self.on_after_zero_grad(self.optimizers) def step(self): if (self.global_forward_batches + 1) % self.accumulation_steps == 0: - self.on_before_optimizer_step(self.driver.optimizers) + self.on_before_optimizers_step(self.optimizers) self.driver.step() + self.on_after_optimizers_step(self.optimizers) def move_data_to_device(self, batch): return self.driver.move_data_to_device(batch) @@ -817,3 +824,5 @@ class Trainer(TrainerEventTrigger): + + diff --git a/fastNLP/core/controllers/utils/utils.py b/fastNLP/core/controllers/utils/utils.py index c3f6aeef..0dce0b27 100644 --- a/fastNLP/core/controllers/utils/utils.py +++ b/fastNLP/core/controllers/utils/utils.py @@ -68,12 +68,18 @@ class TrainerEventTrigger: def on_after_backward(self): self.callback_manager.on_after_backward(self) - def on_before_optimizer_step(self, optimizers): - self.callback_manager.on_before_optimizer_step(self, optimizers) + def on_before_optimizers_step(self, optimizers): + self.callback_manager.on_before_optimizers_step(self, optimizers) + + def on_after_optimizers_step(self, optimizers): + self.callback_manager.on_after_optimizers_step(self, optimizers) def on_before_zero_grad(self, optimizers): self.callback_manager.on_before_zero_grad(self, optimizers) + def on_after_zero_grad(self, optimizers): + self.callback_manager.on_after_zero_grad(self, optimizers) + def on_validate_begin(self): self.callback_manager.on_validate_begin(self) diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index 3537d0b3..11a61dde 100644 --- 
a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -530,14 +530,6 @@ class TorchDDPDriver(TorchDriver): else: raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") - def backward(self, loss): - self.grad_scaler.scale(loss).backward() - - def step(self): - for optimizer in self.optimizers: - self.grad_scaler.step(optimizer) - self.grad_scaler.update() - def is_global_zero(self): return self.global_rank == 0 diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py index 8cbb7acd..eda438d7 100644 --- a/fastNLP/core/drivers/torch_driver/single_device.py +++ b/fastNLP/core/drivers/torch_driver/single_device.py @@ -107,14 +107,6 @@ class TorchSingleDriver(TorchDriver): else: return self._train_step(batch) - def backward(self, loss): - self.grad_scaler.scale(loss).backward() - - def step(self): - for optimizer in self.optimizers: - self.grad_scaler.step(optimizer) - self.grad_scaler.update() - def validate_step(self, batch) -> Dict: # 因为我们 Tester 的逻辑就是将所有的 metric 传给 tester,然后 tester 控制具体 metric 的 update 和 compute;因此不管用户是否 # 实现 validate_step 函数,其都应该返回一个字典,具体使用哪些东西则是在 validate_batch_loop 中每一个具体的 metric 自己去拿的; diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index d2ffbac1..c8a086fe 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -72,6 +72,14 @@ class TorchDriver(Driver): p.grad.requires_grad_(False) p.grad.zero_() + def backward(self, loss): + self.grad_scaler.scale(loss).backward() + + def step(self): + for optimizer in self.optimizers: + self.grad_scaler.step(optimizer) + self.grad_scaler.update() + @staticmethod def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): if is_train: diff --git a/tests/helpers/callbacks/helper_callbacks.py b/tests/helpers/callbacks/helper_callbacks.py index a1697ab0..751d59f2 100644 --- a/tests/helpers/callbacks/helper_callbacks.py +++ b/tests/helpers/callbacks/helper_callbacks.py @@ -101,12 +101,18 @@ class RecordTrainerEventTriggerCallback(Callback): def on_after_backward(self, trainer): print("on_after_backward") - def on_before_optimizer_step(self, trainer, optimizers): - print("on_before_optimizer_step") + def on_before_optimizers_step(self, trainer, optimizers): + print("on_before_optimizers_step") + + def on_after_optimizers_step(self, trainer, optimizers): + print("on_after_optimizers_step") def on_before_zero_grad(self, trainer, optimizers): print("on_before_zero_grad") + def on_after_zero_grad(self, trainer, optimizers): + print("on_after_zero_grad") + def on_validate_begin(self, trainer): print("on_validate_begin") From f87723e2eb7fda7e2da203b9346bd76c0709f7ed Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Wed, 13 Apr 2022 08:40:26 +0000 Subject: [PATCH 18/26] small --- fastNLP/core/drivers/paddle_driver/fleet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index 582ce542..4c937217 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -244,7 +244,6 @@ class PaddleFleetDriver(PaddleDriver): """ if self.local_rank == 0: # 是 rank0 的话,则拉起其它子进程 - print("in launcher") launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc) launcher.launch() # 
设置参数和初始化分布式环境 From 9d71170bef82d01344684c4f3d40bf16b5be9e82 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Wed, 13 Apr 2022 17:04:33 +0800 Subject: [PATCH 19/26] =?UTF-8?q?=E8=A7=A3=E5=86=B3Trainer=E5=9C=A8?= =?UTF-8?q?=E6=96=AD=E7=82=B9=E9=87=8D=E8=AE=AD=E7=9A=84=E6=97=B6=E5=80=99?= =?UTF-8?q?=E6=97=A0=E6=B3=95=E5=AE=9E=E7=8E=B0=E5=87=86=E7=A1=AEload?= =?UTF-8?q?=E5=92=8C=E4=BF=9D=E5=AD=98=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/controllers/evaluator.py | 10 +- fastNLP/core/controllers/trainer.py | 16 +++- fastNLP/core/controllers/utils/state.py | 2 +- .../core/drivers/torch_driver/torch_driver.py | 15 ++- .../samplers/reproducible_batch_sampler.py | 80 ++++++++++------ fastNLP/core/samplers/reproducible_sampler.py | 56 +++++++---- fastNLP/core/samplers/utils.py | 57 +++++++++++- fastNLP/envs/env.py | 2 + .../test_reproducible_batch_sampler.py | 93 +++++++++++++++++-- .../samplers/test_reproducible_sampler.py | 62 ++++++++++++- .../core/samplers/test_unrepeated_sampler.py | 6 +- 11 files changed, 325 insertions(+), 74 deletions(-) diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 479686e1..2e3678d3 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -364,16 +364,16 @@ class _MetricsWrapper: else: args.append(batch) if not isinstance(outputs, dict): - raise RuntimeError(f"The output of your model is of type:`{type(batch)}`, please either directly" + raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly" f" return a dict from your model or use `output_mapping` to convert it into dict type.") if isinstance(metric, Metric): - auto_param_call(metric.update, batch, *args) + auto_param_call(metric.update, outputs, *args) elif _is_torchmetrics_metric(metric): - auto_param_call(metric.update, batch, *args) + auto_param_call(metric.update, outputs, *args) elif _is_allennlp_metric(metric): - auto_param_call(metric.__call__, batch, *args) + auto_param_call(metric.__call__, outputs, *args) elif _is_paddle_metric(metric): - res = auto_param_call(metric.compute, batch, *args) + res = auto_param_call(metric.compute, outputs, *args) metric.update(res) def reset(self): diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 5daee856..6931ed3c 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -105,8 +105,8 @@ class Trainer(TrainerEventTrigger): 如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value; 如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; - 如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 - 为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`; + 如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 + 为 True,那么我们会将 batch 直接透传给模型。注意该参数应用于 `train_step`, `validate_step` 和 `test_step`; :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; :param fp16: 是否开启混合精度训练;默认为 False; :param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 @@ -325,6 +325,8 @@ class Trainer(TrainerEventTrigger): try: while self.cur_epoch_idx < self.n_epochs: + # 这个是防止在 Trainer.load 之后还没结束当前 epoch 又继续 save + 
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch
                 self.driver.set_model_mode("train")
                 self.on_train_epoch_begin()
                 self.driver.set_sampler_epoch(self.dataloader, self.cur_epoch_idx)
@@ -598,7 +600,9 @@ class Trainer(TrainerEventTrigger):
         # 1. callback states 和 每一个callback的具体 callback 函数的 filter 的状态;
         # 2. trainer_state;
         states = {"callback_states": self.on_save_checkpoint(),
-                  "trainer_state": self.trainer_state.state_dict()}
+                  "trainer_state": self.trainer_state.state_dict(),
+                  'num_consumed_batches': self.batch_idx_in_epoch - getattr(self, 'start_batch_idx_in_epoch', 0)
+                  }
 
         # 3. validate filter state;
         if self.evaluator is not None:
@@ -675,9 +679,13 @@ class Trainer(TrainerEventTrigger):
         # 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于
         # '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立
         self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch')
+        self.trainer_state.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + \
+                                                    self.batch_idx_in_epoch
+        # 这个是防止用户在 Trainer.load 之后还没结束当前 epoch 又继续 save
+        self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch
 
         # 5. 恢复所有 callback 的状态;
-        self.train_stepeckpoint(states["callback_states"])
+        self.on_load_checkpoint(states["callback_states"])
 
         self.driver.barrier()
 
diff --git a/fastNLP/core/controllers/utils/state.py b/fastNLP/core/controllers/utils/state.py
index fed9292c..2327c1e5 100644
--- a/fastNLP/core/controllers/utils/state.py
+++ b/fastNLP/core/controllers/utils/state.py
@@ -60,7 +60,7 @@ class TrainerState:
     cur_epoch_idx: 当前正在运行第几个 epoch;
     global_forward_batches: 当前模型总共 forward 了多少个 step;
     batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step;
-    total_batches: 每一个 epoch 会 forward 多少个 step;
+    num_batches_per_epoch: 每一个 epoch 会 forward 多少个 step;
     total_batches: 完整训练过程会 forward 的 step 数量,注意 total_batches = num_batches_per_epoch * n_epochs;
 """
     n_epochs: Optional[int] = None  # 无论如何重新算
diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py
index d2ffbac1..c79ecd0b 100644
--- a/fastNLP/core/drivers/torch_driver/torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/torch_driver.py
@@ -194,9 +194,20 @@ class TorchDriver(Driver):
             sampler = dataloader_args.sampler
         else:
             raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
-
+        num_consumed_batches = states.pop('num_consumed_batches')
         if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
-            states['sampler_states'] = sampler.state_dict()
+            sampler_states = sampler.state_dict()
+            # 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples
+            # 会造成多余实际消耗的问题。
+            num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
+            if num_consumed_samples_array is not None:
+                if isinstance(sampler, ReproducibleSampler):  # 如果是 sampler 的话,需要考虑 batch_size 。
+                    try:
+                        num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
+                    except:  # 有可能 batch_size 为 None,就只有损失精度了
+                        num_consumed_batches = sampler_states['num_consumed_samples']
+                sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
+                assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
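+                # Note: the indirection matters because DataLoader prefetching makes the
+                # sampler's own num_consumed_samples run ahead of what the training loop
+                # has actually consumed; indexing the per-batch history with the trainer's
+                # true `num_consumed_batches` recovers the exact resume point.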
else: raise RuntimeError( 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index d1041f08..d4535bae 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -4,16 +4,18 @@ __all__ = [ ] import math -from array import array from copy import deepcopy from typing import Dict, Union, List from itertools import chain +import os import numpy as np from fastNLP.core.dataset import DataSet from fastNLP.core.log import logger +from .utils import create_array, NumConsumedSamplesArray from abc import abstractmethod +from fastNLP.envs.env import FASTNLP_DEQUE_SIZE class ReproducibleBatchSampler: @@ -34,6 +36,13 @@ class ReproducibleBatchSampler: @abstractmethod def state_dict(self): + """ + 由于现在的DataLoader都存在预取数据的功能,因此请参考 RandomBatchSampler 中 states 里面 num_consumed_samples_array 的实现 + 正确设置该值。其思想是记录每个 index 对应的 num_consumed_samples ,在 Trainer.save 时会根据 Trainer 中的真实 forward + 了多少个 sample 从 num_consumed_samples_array 取出对应的 num_consumed_samples 进行存储。 + + :return: + """ raise NotImplementedError("Each specific batch_sampler should implement its own `state_dict` method.") @abstractmethod @@ -67,7 +76,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): self.batch_size = batch_size self.drop_last = drop_last - self.data_idx = kwargs.get("data_idx", 0) + self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) self.index_list = kwargs.get("index_list", self._iterate_sampler()) self.need_reinitialize = kwargs.get("need_reinitialize", False) @@ -80,36 +89,40 @@ class RandomBatchSampler(ReproducibleBatchSampler): # 说明是在初始化时传入的是一个 sampler,理论上对应于 dataloader 在初始化时没有 batch_size,也没有 batch_sampler 的情况; else: _index_lst.append(idx) - # 64 位机器的 unsigned int 为 4 个字节,能表示的最大大小为 4294967295; - if len(_index_lst) > 4294967295: - # 注意 self.index_list 内存放的是全部数据的 index; - # unsigned long - _index_lst = array("L", _index_lst) - else: - # unsigned int - _index_lst = array("I", _index_lst) + _index_lst = create_array(len(_index_lst), _index_lst) return _index_lst def __iter__(self): if self.need_reinitialize: self.index_list = self._iterate_sampler() - self.data_idx = 0 + self.num_consumed_samples = 0 else: self.need_reinitialize = True batch = [] - if self.data_idx: - index_list = self.index_list[self.data_idx:] + if self.num_consumed_samples: + index_list = self.index_list[self.num_consumed_samples:] else: index_list = self.index_list + + # 记住每个 batch 对应的 consumed_samples, 需要这个原因是由于现在的 dataloader 都存在预取数据的设计,需要再结合Trainer中 + # batch_idx_in_epoch 才能最终确定实际消耗的数据。这个变量需要记录每次yield出去时的真实 num_consumed_samples 的数值。 + self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30), + num_consumed_samples=self.num_consumed_samples) for idx in index_list: batch.append(idx) - self.data_idx += 1 if len(batch) == self.batch_size: + self.num_consumed_samples += self.batch_size # [16, 32, 48, 64,..., ] + self.num_consumed_samples_array.push(self.num_consumed_samples) yield batch batch = [] if len(batch) > 0 and not self.drop_last: + self.num_consumed_samples += len(batch) + self.num_consumed_samples_array.push(self.num_consumed_samples) yield batch + # 需要重置防止边界条件问题 + self.num_consumed_samples = 0 + delattr(self, 'num_consumed_samples_array') def __len__(self) -> int: if self.drop_last: @@ -118,7 +131,13 @@ class RandomBatchSampler(ReproducibleBatchSampler): return 
(len(self.index_list) + self.batch_size - 1) // self.batch_size def state_dict(self) -> Dict: - return {"index_list": deepcopy(self.index_list), "data_idx": self.data_idx, 'sampler_type': self.__class__.__name__} + states = { + "index_list": deepcopy(self.index_list), + "num_consumed_samples": self.num_consumed_samples, + 'sampler_type': self.__class__.__name__ + } + states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None) + return states def load_state_dict(self, states: Dict): assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \ @@ -128,7 +147,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): assert len(_index_list) == len(self.index_list), "The number of samples is different between the checkpoint " \ "record and current dataset." self.index_list = _index_list - self.data_idx = states["data_idx"] + self.num_consumed_samples = states["num_consumed_samples"] self.need_reinitialize = False def set_distributed(self, num_replicas, rank, pad=True): @@ -141,10 +160,10 @@ class RandomBatchSampler(ReproducibleBatchSampler): @property def batch_idx_in_epoch(self): if self.drop_last: - return len(self.index_list) // self.batch_size - (len(self.index_list) - self.data_idx) // self.batch_size + return len(self.index_list) // self.batch_size - (len(self.index_list) - self.num_consumed_samples) // self.batch_size else: return (len(self.index_list) + self.batch_size - 1) // self.batch_size - \ - (len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size + (len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size class BucketedBatchSampler(ReproducibleBatchSampler): @@ -180,7 +199,6 @@ class BucketedBatchSampler(ReproducibleBatchSampler): self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 self.sorted_indices = np.argsort(self.length)[::-1] # 按长度从高到低排序的 - self.batch_size = batch_size self.num_batch_per_bucket = num_batch_per_bucket self.shuffle = shuffle @@ -212,13 +230,13 @@ class BucketedBatchSampler(ReproducibleBatchSampler): self.rank = rank self.pad = pad - num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \ - else len(self.dataset) - - if self.drop_last: - assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \ - "than the number of replicates multiplied " \ - "with batch_size when drop_last=True." + # num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \ + # else len(self.dataset) + # + # if self.drop_last: + # assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \ + # "than the number of replicates multiplied " \ + # "with batch_size when drop_last=True." 
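         # `return self` below keeps the call chainable; a hedged usage sketch (mirroring
         # the multi-rank tests later in this series, not code from this patch):
         #     sampler = BucketedBatchSampler(dataset, length=lengths, batch_size=8)
         #     sampler.set_distributed(num_replicas=2, rank=rank, pad=True)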
return self @@ -243,7 +261,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \ self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) - def __len__(self): + def __len__(self)->int: """ 返回当前 sampler 还会返回多少个 batch 的数据 @@ -309,11 +327,15 @@ class BucketedBatchSampler(ReproducibleBatchSampler): if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: batches = batches[:-1] + self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30), + num_consumed_samples=self.num_consumed_samples) for batch in batches: self.num_consumed_samples += self.num_replicas * len(batch) + self.num_consumed_samples_array.push(self.num_consumed_samples) yield list(map(int, batch)) self.during_iter = False self.num_consumed_samples = 0 + delattr(self, 'num_consumed_samples_array') self.old_batch_size = self.batch_size self.old_num_batch_per_bucket = self.num_batch_per_bucket self.old_num_replicas = self.num_replicas @@ -376,10 +398,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler): 'num_batch_per_bucket': self.num_batch_per_bucket, 'num_replicas': self.num_replicas } + + states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None) return states def load_state_dict(self, states: Dict): - # 如果 self.during_iter 是 True,那么 data_idx 一定是 0; + # 如果 self.during_iter 是 True,那么 num_consumed_samples 一定是 0; assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ "during an unfinished iteration." diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index f48e2fc6..396e69b2 100644 --- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -1,9 +1,14 @@ from typing import Dict, List, Union import math +import os + import numpy as np from fastNLP.core.log import logger from fastNLP.core.dataset import DataSet +from fastNLP.envs.env import FASTNLP_DEQUE_SIZE +from .utils import NumConsumedSamplesArray + __all__ = [ 'ReproducibleSampler', @@ -30,6 +35,13 @@ class ReproducibleSampler: raise NotImplementedError("Each specific sampler should implement its own `__iter__` method.") def state_dict(self): + """ + 由于现在的DataLoader都存在预取数据的功能,因此请参考 RandomSampler 中 states 里面 num_consumed_samples_array 的实现 + 正确设置该值。其思想是记录每个 index 对应的 num_consumed_samples ,在 Trainer.save 时会根据 Trainer 中的真实 forward + 了多少个 sample 从 num_consumed_samples_array 取出对应的 num_consumed_samples 进行存储。 + + :return: + """ raise NotImplementedError("Each specific sampler should implement its own `state_dict` method.") def load_state_dict(self, states): @@ -109,12 +121,15 @@ class RandomSampler(ReproducibleSampler): indices = indices[self.num_consumed_samples:] indices = indices[self.rank:len(indices):self.num_replicas] assert len(indices) == self.num_left_samples - - for index in indices: + self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), + num_consumed_samples=self.num_consumed_samples) + for idx, index in enumerate(indices, start=1): self.num_consumed_samples += self.num_replicas + self.num_consumed_samples_array.push(self.num_consumed_samples) yield index self.during_iter = False self.num_consumed_samples = 0 + delattr(self, 'num_consumed_samples_array') def generate_indices(self) -> List[int]: """ @@ -134,18 +149,13 @@ class RandomSampler(ReproducibleSampler): 
return indices def state_dict(self) -> Dict: - states = { - 'seed': self.seed, - 'epoch': self.epoch, - 'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; - 'sampler_type': self.__class__.__name__, - 'length': len(self.dataset), - 'shuffle': self.shuffle - } + states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, + 'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, + 'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} return states def load_state_dict(self, states: Dict): - # 如果 self.during_iter 是 True,那么 data_idx 一定是 0; + # 如果 self.during_iter 是 True,那么 num_consumed_samples 一定是 0; assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ "during an unfinished iteration." @@ -158,7 +168,7 @@ class RandomSampler(ReproducibleSampler): self.seed = states['seed'] self.epoch = states['epoch'] self.num_consumed_samples = states['num_consumed_samples'] - if self.num_consumed_samples>=length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 + if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 self.num_consumed_samples = 0 if self.shuffle != states['shuffle']: logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, " @@ -245,11 +255,15 @@ class SequentialSampler(RandomSampler): indices = indices[self.rank:len(indices):self.num_replicas] assert len(indices) == self.num_left_samples - for index in indices: + self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), + num_consumed_samples=self.num_consumed_samples) + for idx, index in enumerate(indices, start=1): self.num_consumed_samples += self.num_replicas + self.num_consumed_samples_array.push(self.num_consumed_samples) yield index self.during_iter = False self.num_consumed_samples = 0 + delattr(self, 'num_consumed_samples_array') def generate_indices(self) -> List[int]: """ @@ -260,15 +274,13 @@ class SequentialSampler(RandomSampler): return list(range(len(self.dataset))) def state_dict(self) -> Dict: - states = { - 'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; - 'sampler_type': self.__class__.__name__, - 'length': len(self.dataset), - } + states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, + 'length': len(self.dataset), + 'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} return states def load_state_dict(self, states: Dict): - # 如果 self.during_iter 是 True,那么 data_idx 一定是 0; + # 如果 self.during_iter 是 True,那么 num_consumed_samples 一定是 0; assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ "during an unfinished iteration." 
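
The round-trip these sampler hunks implement can be exercised in a few lines. The sketch
below is illustrative only; it assumes the RandomSampler semantics shown above
(`state_dict`/`load_state_dict`/`set_epoch`, default seed) and mirrors what the new tests
further down assert:

    from fastNLP.core.samplers import RandomSampler

    dataset = list(range(100))                    # anything with __len__ works here
    sampler = RandomSampler(dataset, shuffle=True)
    sampler.set_epoch(0)

    it = iter(sampler)
    seen = {next(it) for _ in range(30)}          # stop mid-epoch, as Trainer.save would
    states = sampler.state_dict()                 # carries num_consumed_samples (+ history)

    resumed = RandomSampler(dataset, shuffle=True)
    resumed.load_state_dict(states)
    resumed.set_epoch(0)
    seen.update(resumed)                          # yields only the remaining 70 indices

    assert len(seen) == len(dataset)
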
@@ -334,9 +346,13 @@ class SortedSampler(SequentialSampler):
         indices = indices[self.rank:len(indices):self.num_replicas]
         assert len(indices) == self.num_left_samples
 
-        for index in indices:
+        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000),
+                                                                  num_consumed_samples=self.num_consumed_samples)
+        for idx, index in enumerate(indices, start=1):
             self.num_consumed_samples += self.num_replicas
+            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield index
         self.during_iter = False
         self.num_consumed_samples = 0
+        delattr(self, 'num_consumed_samples_array')
diff --git a/fastNLP/core/samplers/utils.py b/fastNLP/core/samplers/utils.py
index dd90fe7c..80af1787 100644
--- a/fastNLP/core/samplers/utils.py
+++ b/fastNLP/core/samplers/utils.py
@@ -2,6 +2,9 @@ __all__ = [
     're_instantiate_sampler',
     'conversion_between_reproducible_and_unrepeated_sampler'
 ]
+from array import array
+from typing import Sequence
+from collections import deque
 
 from fastNLP.core.samplers.unrepeated_sampler import *
 from fastNLP.core.samplers.reproducible_sampler import *
@@ -39,4 +42,56 @@ def re_instantiate_sampler(sampler, new_sampler_class=None):
     all_attributes = vars(sampler)
     if new_sampler_class is not None:
         return new_sampler_class(**all_attributes)
-    return type(sampler)(**all_attributes)
\ No newline at end of file
+    return type(sampler)(**all_attributes)
+
+
+def create_array(length, fill_value) -> array:
+    """
+    根据长度自动创建 array ,超过 4294967295 需要使用 'L', 否则使用 'I'
+
+    :param length: array 的长度;
+    :param fill_value: 用于填充 array 的值,可以是单个值,也可以是一个序列;
+    :return:
+    """
+    if not isinstance(fill_value, Sequence):
+        fill_value = [fill_value]*length
+
+    if length > 4294967295:
+        _index_lst = array("L", fill_value)
+    else:
+        _index_lst = array("I", fill_value)
+    return _index_lst
+
+
+class NumConsumedSamplesArray:
+    def __init__(self, buffer_size=2000, num_consumed_samples=0):
+        """
+        保留 buffer_size 个 num_consumed_samples 数据,可以索引得到某个 index 下的 num_consumed_samples 多少
+        ex:
+            array = NumConsumedSamplesArray(buffer_size=3)
+            for i in range(10):
+                array.push(i)
+
+            array[9]  # 输出为9,表示这个位置真实的 num_consumed_samples 是多少。
+            array[6]  # 报错,因为只保留了最近 3 条记录,6 已超出 buffer 保留的范围,即 [7, 8, 9]
+
+        :param buffer_size: 保留多少个历史记录。
+        :param num_consumed_samples: 第一个 num_consumed_samples 是多少。
+        """
+        self.count = 0
+        self.deque = deque(maxlen=buffer_size)
+        if num_consumed_samples is not None:
+            self.push(num_consumed_samples)
+        self.buffer_size = buffer_size
+
+    def __getitem__(self, item):
+        if len(self.deque) == 0:  # 如果没有任何缓存的内容,说明还没有写入,直接返回0
+            return 0
+        assert isinstance(item, int), "Only int index allowed."
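+        # Only a sliding window survives: the deque keeps at most `buffer_size`
+        # entries, so just the positions in [self.count - len(self.deque), self.count)
+        # can still be answered; older positions have been evicted (see the
+        # docstring example above).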
+ assert self.count-len(self.deque)<=item num_samples: + if num_replicas*batch_size > num_samples: return num_batch_per_bucket = 10 samplers = [] lengths = [] - for i in range(num_replica): + for i in range(num_replicas): sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size, num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last) - sampler.set_distributed(num_replica, rank=i, pad=pad) + sampler.set_distributed(num_replicas, rank=i, pad=pad) sampler.set_epoch(0) samplers.append(sampler) lengths.append(len(list(iter(sampler)))) assert len(set(lengths))==1 - bucket_diff = batch_size * num_batch_per_bucket * num_replica + bucket_diff = batch_size * num_batch_per_bucket * num_replicas for bs in zip(*samplers): diff = max(chain(*bs)) - min(chain(*bs)) assert diff <= bucket_diff + + @pytest.mark.parametrize('shuffle', [True, False]) + @pytest.mark.parametrize('drop_last', [True, False]) + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) + @pytest.mark.parametrize('num_replicas', [1, 2, 3]) + def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas): + """ + 测试是否能够正确地恢复使用过的(forward)数据,由于 DataLoader 存在预取,所以 Sampler 自身的 num_consumed_samples 可能 + 偏多 + + :return: + """ + batch_size = 6 + num_batch_per_bucket = 10 + dataset = DatasetWithVaryLength(num_of_data=num_samples) + samplers = [] + for i in range(num_replicas): + sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size, + num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last) + + sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad) + samplers.append(sampler) + count = 0 + already_seen_sets = [set()] + already_seen_set = set() + for batchs in zip(*samplers): + batch = chain(*batchs) + already_seen_set.update(batch) + already_seen_sets.append(deepcopy(already_seen_set)) + count += 1 + if count > 3: + break + states = samplers[0].state_dict() + for i in range(len(already_seen_sets)): + if states['num_consumed_samples_array'] is not None: + states['num_consumed_samples'] = states['num_consumed_samples_array'][i] + sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1, + num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, + drop_last=drop_last) + sampler.set_epoch(0) + already_seen_set = deepcopy(already_seen_sets[i]) + for batch in sampler: + already_seen_set.update(batch) + assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len( + dataset) + + # 测试保存之后再次保存 + sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1, + num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, + drop_last=drop_last) + sampler.set_epoch(0) + if states['num_consumed_samples_array'] is not None: + states['num_consumed_samples'] = states['num_consumed_samples_array'][2] + if len(already_seen_sets)<3: + return + already_seen_set = already_seen_sets[2] + count = 0 + for batch in sampler: + already_seen_set.update(batch) + count += 1 + if count > 6: + break + + states = sampler.state_dict() + if states['num_consumed_samples_array'] is not None: + states['num_consumed_samples'] = states['num_consumed_samples_array'][count] + sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2, + num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, + drop_last=drop_last) + sampler.load_state_dict(states) + sampler.set_epoch(0) + for 
batch in sampler: + already_seen_set.update(batch) + + assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset) diff --git a/tests/core/samplers/test_reproducible_sampler.py b/tests/core/samplers/test_reproducible_sampler.py index 981d6a03..ddf52bcb 100644 --- a/tests/core/samplers/test_reproducible_sampler.py +++ b/tests/core/samplers/test_reproducible_sampler.py @@ -3,6 +3,7 @@ import pytest from functools import partial from itertools import chain +from copy import deepcopy from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler from tests.helpers.datasets.torch_data import TorchNormalDataset @@ -180,6 +181,63 @@ class TestRandomSamplerYh: assert seen <= 1 if pad else seen == 0 assert seen_in_other_rank<=1 # 因为pad可能重复 + @pytest.mark.parametrize('shuffle', [True, False]) + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) + @pytest.mark.parametrize('num_replicas', [1, 2, 3]) + def test_num_consumed_samples_array(self, shuffle, pad, num_samples, num_replicas): + # 测试在 sampler 多生成的时候,可以仍然可以恢复 + dataset = DatasetWithVaryLength(num_of_data=num_samples) + samplers = [] + for i in range(num_replicas): + sampler = RandomSampler(dataset, shuffle=shuffle) + sampler.set_epoch(0) + sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad) + samplers.append(sampler) + count = 0 + already_seen_sets = [set()] + already_seen_set = set() + for idxes in zip(*samplers): + already_seen_set.update(idxes) + already_seen_sets.append(deepcopy(already_seen_set)) + count += 1 + if count > 3: + break + states = samplers[0].state_dict() + for i in range(len(already_seen_sets)): + if states['num_consumed_samples_array'] is not None: + states['num_consumed_samples'] = states['num_consumed_samples_array'][i] + sampler = RandomSampler(dataset, shuffle=shuffle) + already_seen_set = deepcopy(already_seen_sets[i]) + for batch in sampler: + already_seen_set.add(batch) + assert len(already_seen_set) == len(dataset) + # 测试保存之后再次保存 + sampler = RandomSampler(dataset, shuffle=shuffle) + sampler.set_epoch(0) + if states['num_consumed_samples_array'] is not None: + states['num_consumed_samples'] = states['num_consumed_samples_array'][2] + if len(already_seen_sets)<3: + return + already_seen_set = already_seen_sets[2] + count = 0 + for idx in sampler: + already_seen_set.add(idx) + count += 1 + if count > 6: + break + + states = sampler.state_dict() + if states['num_consumed_samples_array'] is not None: + states['num_consumed_samples'] = states['num_consumed_samples_array'][count] + sampler = RandomSampler(dataset, shuffle=shuffle) + sampler.load_state_dict(states) + sampler.set_epoch(0) + for idx in sampler: + already_seen_set.add(idx) + + assert len(already_seen_set)==len(dataset) + class TestRandomSampler: # 测试单卡; @@ -386,7 +444,7 @@ class TestSortedSampler: assert indexes==list(range(num_of_data-1, -1, -1)) @pytest.mark.parametrize('pad', [True, False]) - @pytest.mark.parametrize('num_replica', [2, 3]) + @pytest.mark.parametrize('num_replicas', [2, 3]) @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) def test_multi(self, pad, num_replica, num_of_data): data = DatasetWithVaryLength(num_of_data=num_of_data) @@ -540,7 +598,7 @@ class TestSequentialSampler: assert indexes==list(range(num_of_data)) @pytest.mark.parametrize('pad', [True, False]) - @pytest.mark.parametrize('num_replica', [2, 3]) + @pytest.mark.parametrize('num_replicas', [2, 3]) 
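The save/load tests above all exercise one contract: state captured mid-epoch via state_dict() must let a freshly constructed sampler replay exactly the unseen indices, even if the batch size changes afterwards. A minimal standalone sketch of that contract; SimpleStatefulSampler is a hypothetical stand-in, not fastNLP's RandomSampler:

import random

class SimpleStatefulSampler:
    """Toy sampler illustrating the checkpoint contract the tests above pin down."""
    def __init__(self, num_samples, seed=0):
        self.num_samples = num_samples
        self.seed = seed
        self.num_consumed_samples = 0

    def __iter__(self):
        rng = random.Random(self.seed)
        indices = list(range(self.num_samples))
        rng.shuffle(indices)                      # same permutation on every replay
        for idx in indices[self.num_consumed_samples:]:
            self.num_consumed_samples += 1
            yield idx

    def state_dict(self):
        return {"seed": self.seed, "num_consumed_samples": self.num_consumed_samples}

    def load_state_dict(self, states):
        self.seed = states["seed"]
        self.num_consumed_samples = states["num_consumed_samples"]

sampler = SimpleStatefulSampler(10)
seen = [idx for _, idx in zip(range(4), iter(sampler))]   # consume 4 samples, then "crash"
states = sampler.state_dict()

restored = SimpleStatefulSampler(10)
restored.load_state_dict(states)
seen += list(restored)                                    # replay only the remaining 6
assert sorted(seen) == list(range(10))                    # each index seen exactly once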
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) def test_multi(self, pad, num_replica, num_of_data): data = DatasetWithVaryLength(num_of_data=num_of_data) diff --git a/tests/core/samplers/test_unrepeated_sampler.py b/tests/core/samplers/test_unrepeated_sampler.py index 09601d2c..4a271f41 100644 --- a/tests/core/samplers/test_unrepeated_sampler.py +++ b/tests/core/samplers/test_unrepeated_sampler.py @@ -25,7 +25,7 @@ class TestUnrepeatedSampler: indexes = set(sampler) assert indexes==set(range(num_of_data)) - @pytest.mark.parametrize('num_replica', [2, 3]) + @pytest.mark.parametrize('num_replicas', [2, 3]) @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) @pytest.mark.parametrize('shuffle', [False, True]) def test_multi(self, num_replica, num_of_data, shuffle): @@ -50,7 +50,7 @@ class TestUnrepeatedSortedSampler: indexes = list(sampler) assert indexes==list(range(num_of_data-1, -1, -1)) - @pytest.mark.parametrize('num_replica', [2, 3]) + @pytest.mark.parametrize('num_replicas', [2, 3]) @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) def test_multi(self, num_replica, num_of_data): data = DatasetWithVaryLength(num_of_data=num_of_data) @@ -81,7 +81,7 @@ class TestUnrepeatedSequentialSampler: indexes = list(sampler) assert indexes==list(range(num_of_data)) - @pytest.mark.parametrize('num_replica', [2, 3]) + @pytest.mark.parametrize('num_replicas', [2, 3]) @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) def test_multi(self, num_replica, num_of_data): data = DatasetWithVaryLength(num_of_data=num_of_data) From d2439fe443fbbac0bec5de7368c3e43998c3a79f Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Wed, 13 Apr 2022 09:05:21 +0000 Subject: [PATCH 20/26] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=5FMetricsWrapper=20upd?= =?UTF-8?q?ate=E4=BC=A0=E5=8F=82=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/controllers/evaluator.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 479686e1..2e3678d3 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -364,16 +364,16 @@ class _MetricsWrapper: else: args.append(batch) if not isinstance(outputs, dict): - raise RuntimeError(f"The output of your model is of type:`{type(batch)}`, please either directly" + raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly" f" return a dict from your model or use `output_mapping` to convert it into dict type.") if isinstance(metric, Metric): - auto_param_call(metric.update, batch, *args) + auto_param_call(metric.update, outputs, *args) elif _is_torchmetrics_metric(metric): - auto_param_call(metric.update, batch, *args) + auto_param_call(metric.update, outputs, *args) elif _is_allennlp_metric(metric): - auto_param_call(metric.__call__, batch, *args) + auto_param_call(metric.__call__, outputs, *args) elif _is_paddle_metric(metric): - res = auto_param_call(metric.compute, batch, *args) + res = auto_param_call(metric.compute, outputs, *args) metric.update(res) def reset(self): From b9b0b5343036b47654895bebc20249a3c8882ec0 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Wed, 13 Apr 2022 19:09:27 +0800 Subject: [PATCH 21/26] =?UTF-8?q?=E5=B0=86=20Events=20=E4=BF=AE=E6=94=B9?= =?UTF-8?q?=E4=B8=BA=E5=B0=8F=E5=86=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 
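The small fix in PATCH 20 above matters because metrics consume what the model returns, not what it was fed: keys such as pred exist only in outputs, so passing batch meant the parameter matching could never find them. A sketch of the intended data flow, with auto_param_call reduced to a keyword-filtering helper (a deliberate simplification, not fastNLP's real implementation):

import inspect

def auto_param_call(fn, *dicts):
    """Call fn with only the keyword arguments its signature accepts (simplified)."""
    accepted = set(inspect.signature(fn).parameters)
    kwargs = {}
    for d in dicts:
        kwargs.update({k: v for k, v in d.items() if k in accepted})
    return fn(**kwargs)

class DummyAccuracy:
    def __init__(self):
        self.correct, self.total = 0, 0
    def update(self, pred, target):
        self.correct += sum(int(p == t) for p, t in zip(pred, target))
        self.total += len(target)

metric = DummyAccuracy()
batch = {"input_ids": [[1, 2], [3, 4]], "target": [0, 1]}   # model input
outputs = {"pred": [0, 0]}                                  # model output
auto_param_call(metric.update, outputs, batch)              # outputs first, as in the fix
assert (metric.correct, metric.total) == (1, 2)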
fastNLP/core/callbacks/callback_events.py | 48 +++++++++---------- .../test_checkpoint_callback_torch.py | 2 +- .../test_trainer_wo_evaluator_torch.py | 2 +- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/fastNLP/core/callbacks/callback_events.py b/fastNLP/core/callbacks/callback_events.py index 1c805ac2..7a25c45a 100644 --- a/fastNLP/core/callbacks/callback_events.py +++ b/fastNLP/core/callbacks/callback_events.py @@ -74,30 +74,30 @@ class EventEnum(_SingleEventState, Enum): @unique class Events(EventEnum): - ON_AFTER_TRAINER_INITIALIZED = "on_after_trainer_initialized" - ON_SANITY_CHECK_BEGIN = "on_sanity_check_begin" - ON_SANITY_CHECK_END = "on_sanity_check_end" - ON_TRAIN_BEGIN = "on_train_begin" - ON_TRAIN_END = "on_train_end" - ON_TRAIN_EPOCH_BEGIN = "on_train_epoch_begin" - ON_TRAIN_EPOCH_END = "on_train_epoch_end" - ON_FETCH_DATA_BEGIN = "on_fetch_data_begin" - ON_FETCH_DATA_END = "on_fetch_data_end" - ON_TRAIN_BATCH_BEGIN = "on_train_batch_begin" - ON_TRAIN_BATCH_END = "on_train_batch_end" - ON_EXCEPTION = "on_exception" - ON_SAVE_MODEL = "on_save_model" - ON_LOAD_MODEL = "on_load_model" - ON_SAVE_CHECKPOINT = "on_save_checkpoint" - ON_LOAD_CHECKPOINT = "on_load_checkpoint" - ON_BEFORE_BACKWARD = "on_before_backward" - ON_AFTER_BACKWARD = "on_after_backward" - ON_BEFORE_OPTIMIZERS_STEP = "on_before_optimizers_step" - ON_AFTER_OPTIMIZERS_STEP = "on_after_optimizers_step" - ON_BEFORE_ZERO_GRAD = "on_before_zero_grad" - ON_AFTER_ZERO_GRAD = "on_after_zero_grad" - ON_VALIDATE_BEGIN = "on_validate_begin" - ON_VALIDATE_END = "on_validate_end" + on_after_trainer_initialized = "on_after_trainer_initialized" + on_sanity_check_begin = "on_sanity_check_begin" + on_sanity_check_end = "on_sanity_check_end" + on_train_begin = "on_train_begin" + on_train_end = "on_train_end" + on_train_epoch_begin = "on_train_epoch_begin" + on_train_epoch_end = "on_train_epoch_end" + on_fetch_data_begin = "on_fetch_data_begin" + on_fetch_data_end = "on_fetch_data_end" + on_train_batch_begin = "on_train_batch_begin" + on_train_batch_end = "on_train_batch_end" + on_exception = "on_exception" + on_save_model = "on_save_model" + on_load_model = "on_load_model" + on_save_checkpoint = "on_save_checkpoint" + on_load_checkpoint = "on_load_checkpoint" + on_before_backward = "on_before_backward" + on_after_backward = "on_after_backward" + on_before_optimizers_step = "on_before_optimizers_step" + on_after_optimizers_step = "on_after_optimizers_step" + on_before_zero_grad = "on_before_zero_grad" + on_after_zero_grad = "on_after_zero_grad" + on_validate_begin = "on_validate_begin" + on_validate_end = "on_validate_end" class EventsList: diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py index 557c31b2..fe0a3582 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -238,7 +238,7 @@ def test_model_checkpoint_callback_2( from fastNLP.core.callbacks.callback_events import Events - @Trainer.on(Events.ON_TRAIN_EPOCH_END) + @Trainer.on(Events.on_train_epoch_end) def raise_exception(trainer): if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: raise NotImplementedError diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py index 0da8c976..82fa3af0 100644 --- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py +++ 
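Renaming the members to lowercase makes the enum member, its string value, and the Callback method it dispatches to all share one spelling, so code like getattr(Callback, event.value) and @Trainer.on(Events.on_train_epoch_end) read identically. A minimal sketch of the pattern, using a plain str-valued Enum (fastNLP's Events actually derives from EventEnum):

from enum import Enum, unique

@unique
class Events(str, Enum):
    on_train_begin = "on_train_begin"
    on_train_epoch_end = "on_train_epoch_end"

registry = {}

def on(event: Events):
    def wrapper(fn):
        registry.setdefault(event, []).append(fn)
        return fn
    return wrapper

@on(Events.on_train_epoch_end)
def report(trainer):
    print(f"epoch finished on {trainer}")

for fn in registry[Events.on_train_epoch_end]:
    fn(trainer="toy-trainer")        # prints: epoch finished on toy-trainer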
b/tests/core/controllers/test_trainer_wo_evaluator_torch.py @@ -254,7 +254,7 @@ def test_trainer_on_exception( ): from fastNLP.core.callbacks.callback_events import Events - @Trainer.on(Events.ON_TRAIN_EPOCH_END) + @Trainer.on(Events.on_train_epoch_end) def raise_exception(trainer): if trainer.driver.get_local_rank() == cur_rank: raise NotImplementedError From 2f23d80ccc19645bca43d44cdefd208778065a6f Mon Sep 17 00:00:00 2001 From: YWMditto Date: Thu, 14 Apr 2022 00:45:17 +0800 Subject: [PATCH 22/26] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86=20trainer=20?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=20validate=20=E7=9A=84=E8=B0=83=E7=94=A8?= =?UTF-8?q?=E7=9A=84=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/callback_events.py | 16 +---- fastNLP/core/controllers/trainer.py | 60 +++++++++---------- .../test_trainer_w_evaluator_torch.py | 44 +++++++++++++- 3 files changed, 73 insertions(+), 47 deletions(-) diff --git a/fastNLP/core/callbacks/callback_events.py b/fastNLP/core/callbacks/callback_events.py index 7a25c45a..ef972b35 100644 --- a/fastNLP/core/callbacks/callback_events.py +++ b/fastNLP/core/callbacks/callback_events.py @@ -171,20 +171,8 @@ class Filter: self.num_called += 1 # 因为我们的 callback 函数的输入是固定的,而且我们能够保证第一个参数一定是 trainer; - # 因此我们就可以这样进行操作,将 trainer 从 callback 函数的输入中取出来,送到我们的 trainer 里去,从而实现一些复杂的逻辑; - # 与此同时,当我们发现 Filter 所修饰的函数的输入第一个参数不是 trainer 时,我们就只传入一个 self 到 _filter 函数中; - - # 提取参数的逻辑; - trainer = kwargs.get("trainer", None) - - if trainer is None and len(args) > 0: - trainer = args[0] - if isinstance(trainer, fastNLP.Trainer): # 这里因为重复调用的问题,我们不能直接使用 fastNLP.Trainer,因为 Trainer - # 也会调用这个 module,但是 Controller 不会; - param = (self, trainer) - else: - param = (self, ) - if self._filter(*param): + trainer = args[0] + if self._filter(self, trainer): self.num_executed += 1 return fn(*args, **kwargs) diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index d8e984a1..e1f31375 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -224,13 +224,14 @@ class Trainer(TrainerEventTrigger): # 为了在 train 的循环中每次都检查是否需要进行 validate,这里我们提前在 trainer 初始化的时候就将对应时间点需要运行的函数确定下来; # _epoch_validate 表示每隔几个 epoch validate 一次;_step_validate 表示每隔几个 step validate 一次; self.evaluator = None - self.epoch_validate = lambda *args, **kwargs: ... - self.step_validate = lambda *args, **kwargs: ... 
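The simplified Filter.__call__ in PATCH 22 drops the old trainer-extraction gymnastics by assuming the decorated callback always receives the trainer as its first positional argument. A reduced sketch of the whole protocol; the __init__ semantics (every / filter_fn) are reconstructed from the call sites rather than shown in this diff, so treat them as an assumption:

from functools import wraps

class Filter:
    def __init__(self, every=None, filter_fn=None):
        # exactly one of `every` (fire every N calls) or `filter_fn` is expected
        if filter_fn is not None:
            self._filter = lambda self_, trainer: filter_fn(trainer)
        else:
            self._filter = lambda self_, trainer: self_.num_called % every == 0
        self.num_called = 0
        self.num_executed = 0

    def __call__(self, fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            self.num_called += 1
            trainer = args[0]            # first argument is always the trainer
            if self._filter(self, trainer):
                self.num_executed += 1
                return fn(*args, **kwargs)
        return wrapper

@Filter(every=2)
def on_train_batch_end(trainer):
    print("fired at call", trainer)

for step in range(1, 5):
    on_train_batch_end(step)             # fires on the 2nd and 4th call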
self.monitor = monitor self.larger_better = larger_better if metrics is not None and validate_dataloaders is not None: if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") + if callable(validate_every): + logger.info("Notice you are using a 'filter function' as the value of parameter `validate_every`, " + "and in this way, the kind of controlling frequency is depending on the 'step'.") self.evaluator = Evaluator( model=model, @@ -248,16 +249,6 @@ class Trainer(TrainerEventTrigger): progress_bar=kwargs.get('progress_bar', 'auto') ) - if callable(validate_every): - self._step_validate_filter = Filter(filter_fn=validate_every) - logger.info("Notice you are using a 'filter function' as the value of parameter `validate_every`, " - "and in this way, the kind of controlling frequency is depending on the 'step'.") - elif validate_every < 0: - self._epoch_validate_filter = Filter(every=-validate_every) - else: - # validate_every > 0 - self._step_validate_filter = Filter(every=validate_every) - self.metrics = metrics self.validate_every = validate_every @@ -356,31 +347,38 @@ class Trainer(TrainerEventTrigger): raise e def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl): - def _validate_fn(validate_fn: Callable, trainer: Trainer) -> None: + def _validate_fn(trainer: Trainer, validate_fn: Callable) -> None: trainer.on_validate_begin() _validate_res: dict = validate_fn() trainer.on_validate_end(_validate_res) + self.validate_fn = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) + + def step_validate(self): if self.evaluator is not None: + should_run_validate = False + if callable(self.validate_every): - self.step_validate = self._step_validate_filter(partial( - _validate_fn, - partial(self.evaluator.run, num_eval_batch_per_dl), - self - )) - elif self.validate_every < 0: - self.epoch_validate = self._epoch_validate_filter(partial( - _validate_fn, - partial(self.evaluator.run, num_eval_batch_per_dl), - self - )) - else: - # validate_every > 0 - self.step_validate = self._step_validate_filter(partial( - _validate_fn, - partial(self.evaluator.run, num_eval_batch_per_dl), - self - )) + if self.validate_every(self): + should_run_validate = True + elif self.validate_every > 0: + if self.global_forward_batches % self.validate_every == 0: + should_run_validate = True + + if should_run_validate: + self.validate_fn() + + def epoch_validate(self): + if self.evaluator is not None: + should_run_validate = False + + if isinstance(self.validate_every, int) and self.validate_every < 0: + validate_every = -self.validate_every + if self.cur_epoch_idx % validate_every == 0: + should_run_validate = True + + if should_run_validate: + self.validate_fn() def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): r""" diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py index 699ee3b9..70d03f8c 100644 --- a/tests/core/controllers/test_trainer_w_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py @@ -98,14 +98,16 @@ def model_and_optimizers(request): # 测试一下普通的情况; -@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) #, ("torch", 1), ("torch", [0, 1]) +@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) 
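After this refactor the three accepted forms of validate_every are dispatched explicitly at run time instead of being baked into Filter-wrapped partials at init time: a callable is asked directly, a positive int is checked against the global batch count, and a negative int against the epoch index. A condensed, runnable mirror of the dispatch (toy trainer; attribute names follow the diff above):

class ToyTrainer:
    def __init__(self, validate_every):
        self.validate_every = validate_every
        self.global_forward_batches = 0
        self.cur_epoch_idx = 0

    def run_evaluate(self):
        print(f"evaluate at batch {self.global_forward_batches}")

    def step_validate(self):             # called after every batch
        if callable(self.validate_every):
            if self.validate_every(self):
                self.run_evaluate()
        elif self.validate_every > 0 and self.global_forward_batches % self.validate_every == 0:
            self.run_evaluate()

    def epoch_validate(self):            # called after every epoch
        if isinstance(self.validate_every, int) and self.validate_every < 0:
            if self.cur_epoch_idx % (-self.validate_every) == 0:
                self.run_evaluate()

trainer = ToyTrainer(validate_every=lambda t: t.global_forward_batches % 10 == 0)
for step in range(1, 31):
    trainer.global_forward_batches = step
    trainer.step_validate()              # evaluates at batches 10, 20 and 30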
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]]) +@pytest.mark.parametrize("validate_every", [-3]) @magic_argv_env_context def test_trainer_torch_with_evaluator( model_and_optimizers: TrainerParameters, driver, device, callbacks, + validate_every, n_epochs=10, ): trainer = Trainer( @@ -118,11 +120,11 @@ def test_trainer_torch_with_evaluator( input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics, + validate_every=validate_every, n_epochs=n_epochs, callbacks=callbacks, output_from_new_proc="all" - ) trainer.run() @@ -169,4 +171,42 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( dist.destroy_process_group() +@pytest.mark.parametrize("driver,device", [("torch", 1)]) # ("torch", [0, 1]),("torch", 1) +@magic_argv_env_context +def test_trainer_validate_every( + model_and_optimizers: TrainerParameters, + driver, + device, + n_epochs=6, +): + + def validate_every(trainer): + if trainer.global_forward_batches % 10 == 0: + print(trainer) + print("\nfastNLP test validate every.\n") + print(trainer.global_forward_batches) + return True + + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + validate_dataloaders=model_and_optimizers.validate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + + n_epochs=n_epochs, + output_from_new_proc="all", + validate_every=validate_every + ) + + trainer.run() + + if dist.is_initialized(): + dist.destroy_process_group() + + From 1452aa8f6c54e2ad313d93687baf491b9cb37559 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Thu, 14 Apr 2022 13:50:53 +0800 Subject: [PATCH 23/26] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BA=86=20dist=20?= =?UTF-8?q?=E4=B8=BA=20None=20=E6=97=B6=E7=9A=84=20set=5Fdist=5Frepro=5Fda?= =?UTF-8?q?taloader=20=E7=9A=84=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/torch_driver/ddp.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index 11a61dde..c673fe62 100644 --- a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -471,12 +471,11 @@ class TorchDDPDriver(TorchDriver): raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " "control.") else: - if isinstance(dist, ReproducibleBatchSampler): - dist = re_instantiate_sampler(dist) - return replace_batch_sampler(dataloader, dist) - if isinstance(dist, ReproducibleSampler): - dist = re_instantiate_sampler(dist) - return replace_sampler(dataloader, dist) + args = self.get_dataloader_args(dataloader) + if isinstance(args.batch_sampler, ReproducibleBatchSampler): + return replace_batch_sampler(dataloader, re_instantiate_sampler(args.batch_sampler)) + if isinstance(args.sampler, ReproducibleSampler): + return replace_sampler(dataloader, re_instantiate_sampler(args.sampler)) return dataloader # trainer elif dist == "dist": From 64fa182aeb02d7e06882eede6b27aa5b311c9658 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Thu, 14 Apr 2022 08:01:04 +0000 Subject: [PATCH 24/26] 
=?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=96=AD=E7=82=B9?= =?UTF-8?q?=E9=87=8D=E8=AE=AD=E9=83=A8=E5=88=86=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/paddle_driver/fleet.py | 16 ++++++++-------- .../drivers/paddle_driver/paddle_driver.py | 19 ++++++++++++++++--- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index 4c937217..3f29e4dd 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -325,7 +325,6 @@ class PaddleFleetDriver(PaddleDriver): assert dataloader.dataset_kind != _DatasetKind.ITER, \ "FastNLP does not support `IteratorDataset` now." # 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用; - # 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; if isinstance(dist, ReproducibleBatchSampler): dist.set_distributed( num_replicas=self.world_size, @@ -345,15 +344,16 @@ class PaddleFleetDriver(PaddleDriver): # trainer, evaluator if dist is None: if reproducible: - raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " + raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our " "control.") else: - if isinstance(dist, ReproducibleBatchSampler): - dist = re_instantiate_sampler(dist) - return replace_batch_sampler(dataloader, dist) - if isinstance(dist, ReproducibleSampler): - dist = re_instantiate_sampler(dist) - return replace_sampler(dataloader, dist) + args = self.get_dataloader_args(dataloader) + if isinstance(args.batch_sampler, ReproducibleBatchSampler): + batch_sampler = re_instantiate_sampler(args.batch_sampler) + return replace_batch_sampler(dataloader, batch_sampler) + if isinstance(args.sampler, ReproducibleSampler): + sampler = re_instantiate_sampler(args.sampler) + return replace_sampler(dataloader, sampler) return dataloader # trainer elif dist == "dist": diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 4362dcce..cc870536 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -66,8 +66,8 @@ class PaddleDriver(Driver): :param set_to_none: 用来判断是否需要将梯度直接置为 None;Paddle中这个参数无效。 """ - # if set_to_none: - # log.warning("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.") + if set_to_none: + logger.warning_once("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.") for optimizer in self.optimizers: optimizer.clear_grad() @@ -254,8 +254,21 @@ class PaddleDriver(Driver): else: raise RuntimeError("This condition is not supposed to appear. 
Please report a bug to us.") + num_consumed_batches = states.pop('num_consumed_batches') if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): - states['sampler_states'] = sampler.state_dict() + sampler_states = sampler.state_dict() + # 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples + # 会造成多余实际消耗的问题。 + num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) + if num_consumed_samples_array is not None: + if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 + try: + num_consumed_batches = num_consumed_batches * dataloader_args.batch_size + except: # 有可能 batch_size 为 None,就只有损失精度了 + num_consumed_batches = sampler_states['num_consumed_samples'] + sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] + assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." + else: raise RuntimeError( 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') From 7c6e8b20a8612a20d2dfe077896fe537ac7c2d1b Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Thu, 14 Apr 2022 08:01:14 +0000 Subject: [PATCH 25/26] small --- fastNLP/core/samplers/reproducible_batch_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index d4535bae..73621b5f 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -151,7 +151,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): self.need_reinitialize = False def set_distributed(self, num_replicas, rank, pad=True): - raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.") + raise RuntimeError(f"RandomBatchSampler does not support to change to distributed training.") def set_epoch(self, epoch): if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch): From 16a467393c714b249530d349d65739f797d64c62 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Thu, 14 Apr 2022 16:02:41 +0800 Subject: [PATCH 26/26] =?UTF-8?q?1.montior=E5=85=81=E8=AE=B8=E4=BC=A0?= =?UTF-8?q?=E5=85=A5callable=E7=9A=84=E5=AF=B9=E8=B1=A1=E8=BF=9B=E8=A1=8C?= =?UTF-8?q?=E9=80=89=E6=8B=A9;=202.=E8=A7=A3=E5=86=B3Sampler=E4=B8=AD?= =?UTF-8?q?=E5=AD=98=E5=9C=A8=E7=9A=84=E5=BE=AA=E7=8E=AF=E5=BC=95=E7=94=A8?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/callback.py | 40 +++- fastNLP/core/callbacks/checkpoint_callback.py | 11 +- fastNLP/core/callbacks/early_stop_callback.py | 11 +- .../callbacks/load_best_model_callback.py | 11 +- fastNLP/core/callbacks/progress_callback.py | 4 +- fastNLP/core/callbacks/utils.py | 18 +- fastNLP/core/collators/collator.py | 23 ++- fastNLP/core/controllers/evaluator.py | 19 +- fastNLP/core/controllers/trainer.py | 64 +++--- fastNLP/core/controllers/utils/utils.py | 10 +- fastNLP/core/dataset/dataset.py | 12 +- .../drivers/jittor_driver/jittor_driver.py | 4 +- .../drivers/paddle_driver/paddle_driver.py | 2 +- .../core/drivers/torch_driver/dist_utils.py | 10 +- .../core/drivers/torch_driver/torch_driver.py | 2 +- fastNLP/core/log/logger.py | 6 +- fastNLP/core/metrics/accuracy.py | 5 +- fastNLP/core/samplers/__init__.py | 6 +- 
fastNLP/core/samplers/conversion_utils.py | 33 ++++ .../samplers/reproducible_batch_sampler.py | 19 +- fastNLP/core/samplers/reproducible_sampler.py | 14 +- fastNLP/core/samplers/utils.py | 34 +--- fastNLP/core/utils/__init__.py | 3 +- fastNLP/core/utils/utils.py | 143 +++++++++++--- fastNLP/io/data_bundle.py | 10 +- fastNLP/io/pipe/classification.py | 2 +- fastNLP/io/pipe/construct_graph.py | 2 +- fastNLP/io/pipe/pipe.py | 2 +- tests/core/utils/test_utils.py | 187 ++++++++++++++++++ tests/helpers/utils.py | 20 +- 30 files changed, 505 insertions(+), 222 deletions(-) create mode 100644 fastNLP/core/samplers/conversion_utils.py create mode 100644 tests/core/utils/test_utils.py diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py index 0b9020fe..902421c8 100644 --- a/fastNLP/core/callbacks/callback.py +++ b/fastNLP/core/callbacks/callback.py @@ -10,6 +10,7 @@ from .utils import _get_monitor_value from fastNLP.core.callbacks.callback_events import _SingleEventState from fastNLP.core.log import logger from fastNLP.core.utils import apply_to_collection +from fastNLP.core.utils.utils import _check_valid_parameters_number class Callback: @@ -299,7 +300,11 @@ class HasMonitorCallback(Callback): self.must_have_moinitor = must_have_monitor def set_monitor(self, monitor, larger_better): - self.monitor = str(monitor) if monitor is not None else None + if callable(monitor): # 检查是否能够接受一个参数 + _check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor') + self.monitor = monitor + else: + self.monitor = str(monitor) if monitor is not None else None self.larger_better = bool(larger_better) if larger_better: self.monitor_value = float('-inf') @@ -322,24 +327,33 @@ class HasMonitorCallback(Callback): raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " f"You can set it in the initialization or through Trainer.") - def get_monitor_value(self, results:Dict)->float: + def get_monitor_value(self, results:Dict)->Union[float, None]: """ 获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用匹配的方式寻找,并把匹配到的设置到 self._real_monitor 属性上。 :param results: - :return: + :return: 如果为 None ,表明此次没有找到合适的monitor """ if len(results)==0: - return 0 + return None # 保证所有的 tensor 都被转换为了 python 特定的类型 results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item()) use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=results) - if self._real_monitor != use_monitor: # 发生了替换需要打印 - logger.warning( - f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " - f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.") + if monitor_value is None: + return monitor_value + # 第一次运行 + if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor: + logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " + f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.") + # 检测到此次和上次不同。 + elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor: + logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. " + f"The expected monitor is:`{self.monitor}`, last used monitor is:" + f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. 
Please consider using a " + f"customized monitor function when the evaluation results are varying between validation.") + self._real_monitor = use_monitor return monitor_value @@ -347,10 +361,12 @@ class HasMonitorCallback(Callback): """ 检测 monitor_value 是否是更好的 - :param monitor_value: + :param monitor_value: 待检查的 monitor_value 。如果为 None ,返回 False :param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。 :return: """ + if monitor_value is None: + return False better = self.is_former_monitor_value_better(monitor_value, self.monitor_value) if keep_if_better and better: self.monitor_value = monitor_value @@ -364,6 +380,12 @@ class HasMonitorCallback(Callback): :param monitor_value2: :return: """ + if monitor_value1 is None and monitor_value2 is None: + return True + if monitor_value1 is None: + return False + if monitor_value2 is None: + return True better = False if (self.larger_better and monitor_value1 > monitor_value2) or \ (not self.larger_better and monitor_value1 < monitor_value2): diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py index 82bfe404..a5be2b4c 100644 --- a/fastNLP/core/callbacks/checkpoint_callback.py +++ b/fastNLP/core/callbacks/checkpoint_callback.py @@ -10,8 +10,7 @@ from copy import deepcopy import fastNLP -from .callback import Callback, HasMonitorCallback -from fastNLP.core.callbacks.utils import _get_monitor_value +from .callback import HasMonitorCallback from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_LAUNCH_TIME from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir @@ -166,6 +165,8 @@ class CheckpointCallback(HasMonitorCallback): """ if self.save_topk is not None: monitor_value = self.get_monitor_value(results=results) + if monitor_value is None: + return folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ f"-{self._real_monitor}_{monitor_value}" @@ -231,7 +232,8 @@ class ModelCheckpointCallback(CheckpointCallback): 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 - 的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。 + 的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型), + 返回一个 float 值作为 monitor 的结果。 :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 :param save_every_n_epochs: 多少个 epoch 保存一次。 @@ -278,7 +280,8 @@ class TrainerCheckpointCallback(CheckpointCallback): 若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 - 的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。 + 的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型), + 返回一个 float 值作为 monitor 的结果。 :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 :param save_every_n_epochs: 多少个 epoch 保存一次。 diff --git a/fastNLP/core/callbacks/early_stop_callback.py b/fastNLP/core/callbacks/early_stop_callback.py index 602236f7..b1842d43 100644 --- a/fastNLP/core/callbacks/early_stop_callback.py +++ b/fastNLP/core/callbacks/early_stop_callback.py @@ -2,17 +2,18 @@ __all__ = [ 'EarlyStopCallback' ] -from typing import Dict +from typing import Dict, Union, Callable from .callback import HasMonitorCallback from fastNLP.core.utils.exceptions import EarlyStopException class EarlyStopCallback(HasMonitorCallback): - def 
__init__(self, monitor:str=None, larger_better:bool=True, patience:int=10): + def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10): """ - :param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。 + :param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 + evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 :param larger_better: monitor 的值是否是越大越好。 :param patience: 多少次 validate 不没有提升就停止。 """ @@ -21,9 +22,9 @@ class EarlyStopCallback(HasMonitorCallback): self.patience = patience def on_validate_end(self, trainer, results): - if len(results)==0: - return monitor_value = self.get_monitor_value(results) + if monitor_value is None: + return if self.is_better_monitor_value(monitor_value, keep_if_better=True): self.wait = 0 else: diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py index 9a4bb65f..e068326b 100644 --- a/fastNLP/core/callbacks/load_best_model_callback.py +++ b/fastNLP/core/callbacks/load_best_model_callback.py @@ -3,7 +3,7 @@ __all__ = [ ] import os -from typing import Optional, Callable +from typing import Optional, Callable, Union from .callback import HasMonitorCallback from io import BytesIO import shutil @@ -14,14 +14,15 @@ from fastNLP.envs import all_rank_call class LoadBestModelCallback(HasMonitorCallback): - def __init__(self, monitor:str=None, larger_better:bool = True, only_state_dict:bool = True, + def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool = True, only_state_dict:bool = True, save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None, model_load_fn:Optional[Callable] = None, delete_after_train:bool = True): """ 保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 - :param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。 + :param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 + evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 :param larger_better: 该 metric 值是否是越大越好。 :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 @@ -78,9 +79,9 @@ class LoadBestModelCallback(HasMonitorCallback): self.get_monitor_value(sanity_check_res) def on_validate_end(self, trainer, results): - if len(results)==0: - return monitor_value = self.get_monitor_value(results) + if monitor_value is None: + return if self.is_better_monitor_value(monitor_value, keep_if_better=True): if self.real_save_folder: trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py index 756d236b..67176387 100644 --- a/fastNLP/core/callbacks/progress_callback.py +++ b/fastNLP/core/callbacks/progress_callback.py @@ -45,6 +45,7 @@ class RichCallback(ProgressCallback): :param print_every: 多少个 batch 更新一次显示。 :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。如果为 None ,会尝试使用 trainer 中设置的 monitor 。 + 也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 :param larger_better: 是否是monitor的结果越大越好。 :param format_json: 是否format json再打印 """ @@ -135,7 +136,8 @@ class RawTextCallback(ProgressCallback): :param print_every: 多少个 batch 更新一次显示。 :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 - :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 + :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。也可以传入一个函数,接受参数为 evaluation 的结果( + 字典类型),返回一个 float 值作为 monitor 的结果。 :param 
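Accepting a callable monitor lets a callback compute its own scalar from the raw evaluation dict instead of relying on the shortest-common-substring key matching. A usage sketch; the "acc#acc" key mirrors the tests elsewhere in this series, while the f1 key and the values are invented:

def combined_monitor(results):
    # aggregate two reported metrics into a single scalar to monitor
    return (results["acc#acc"] + results["f1#f1"]) / 2

results = {"acc#acc": 0.91, "f1#f1": 0.87}
assert abs(combined_monitor(results) - 0.89) < 1e-9

# then, e.g.: EarlyStopCallback(monitor=combined_monitor, larger_better=True)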
larger_better: 是否是monitor的结果越大越好。 :param format_json: 是否format json再打印 """ diff --git a/fastNLP/core/callbacks/utils.py b/fastNLP/core/callbacks/utils.py index 2720ba3f..7ece3bb9 100644 --- a/fastNLP/core/callbacks/utils.py +++ b/fastNLP/core/callbacks/utils.py @@ -1,9 +1,10 @@ -from typing import Optional +from typing import Optional, Union from fastNLP.core.log.logger import logger from difflib import SequenceMatcher +from fastNLP.core.utils.utils import _get_fun_msg -def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->(str, float): +def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str], res: dict) ->(str, float): """ 从res中寻找 monitor 并返回。如果 monitor 没找到则尝试用 _real_monitor ,若 _real_monitor 为 None 则尝试使用 monitor 的值进行 匹配。 @@ -11,10 +12,19 @@ def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->( :param monitor: :param real_monitor: :param res: - :return: 返回两个值(str, value),其中str就是最终要到的key,value就是这个key对应的value + :return: 返回两个值(str, value),其中str就是最终要到的key,value就是这个key对应的value。如果value为None说明当前results中没有 + 找到对应的 monitor """ if len(res)==0: - return monitor, 0 + return monitor, None + + if callable(monitor): + try: + monitor_value = monitor(res) + except BaseException as e: + logger.error(f"Exception happens when calling customized monitor function:{_get_fun_msg(monitor)}.") + raise e + return monitor, monitor_value if monitor in res: return monitor, res[monitor] diff --git a/fastNLP/core/collators/collator.py b/fastNLP/core/collators/collator.py index f468dd4c..b6b6de14 100644 --- a/fastNLP/core/collators/collator.py +++ b/fastNLP/core/collators/collator.py @@ -5,7 +5,7 @@ __all__ = [ from abc import ABCMeta, abstractmethod -from typing import Any, Dict, List, Callable, Union +from typing import Any, Dict, List, Callable, Union, Tuple from numbers import Number import warnings @@ -35,7 +35,7 @@ class SetInputOrTargetException(Exception): self.field_name = field_name # 标示当前 field 的名称 -def _get_ele_type_and_dim(cell: Any, dim=0): +def _get_ele_type_and_dim(cell: Any, dim=0) -> Tuple[Any, int]: r""" 识别cell的类别与dimension的数量 @@ -206,7 +206,7 @@ class AutoCollator(Collator): def __init__(self, as_numpy: bool): super(AutoCollator, self).__init__() self.pad_field_value = {} # field padding 自定义的 padding 值, 默认为0 - self.need_inputs = [] # 需要的 field name + self.need_inputs = set() # 需要的 field name self.field_dtypes = None # 每列数据单元的 dtype 类型 self.field_dims = None # 每列数据单元维度 self.as_numpy = as_numpy @@ -214,10 +214,17 @@ class AutoCollator(Collator): def __call__(self, ins_lst: List[Dict]) -> dict: if len(self.need_inputs) == 0: raise ValueError({"set_inputs is None, you should use set_inputs method first!!"}) + # TODO 这里应该是先 check 有哪些需要 padding,然后check这些是否是可以pad的 + # 第一种情况,设置了 set_input 的值 # 第二种情况, 根据数据的类型的判断是否 padding if self.field_dtypes is None and self.field_dims is None: - self.field_dtypes, self.field_dims = _get_ds_type_dim(ins_lst[0]) + field_dtypes, field_dims = {}, {} + for key, value in ins_lst[0].items(): + if key in self.need_inputs and self.pad_field_value.get(key, 0) is not None: + field_dtypes[key], field_dims[key] = _get_ele_type_and_dim(value) + self.field_dtypes = field_dtypes + self.field_dims = field_dims pack_ins_lst, pad_ins_lst = {field_name: [] for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {} @@ -233,13 +240,13 @@ class AutoCollator(Collator): if len(self.pad_field_value.keys()) > 0: # 去掉不需要 pad 的列,如果 set_input 的列不存在则忽略 - drop_field_names = [] + non_pad_field_names = [] for k, v in 
self.pad_field_value.items(): if v is None: - drop_field_names.append(k) + non_pad_field_names.append(k) # drop_field_names = list(set(list(ins_lst[0].keys())) - set(drop_fields)) - for field_name in drop_field_names: + for field_name in non_pad_field_names: field_array = pack_ins_lst.pop(field_name) pad_ins_lst[field_name] = np.array(field_array) @@ -269,7 +276,7 @@ class AutoCollator(Collator): def set_input(self, *field_names): for field_name in field_names: - self.need_inputs.append(field_name) + self.need_inputs.add(field_name) def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool): diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 2e3678d3..5196f8c7 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -11,11 +11,12 @@ __all__ = [ from fastNLP.core.drivers import Driver from fastNLP.core.drivers.utils import choose_driver from .loops import Loop, EvaluateBatchLoop -from fastNLP.core.utils import check_fn_not_empty_params, auto_param_call, dataclass_to_dict, \ +from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \ match_and_substitute_params, f_rich_progress from fastNLP.core.metrics import Metric from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader +from fastNLP.core.utils.utils import _check_valid_parameters_number from fastNLP.core.log import logger @@ -38,11 +39,11 @@ class Evaluator: driver: Union[str, Driver] = 'single', device: Optional[Union[int, List[int], str]] = None, batch_step_fn: Optional[callable] = None, - mode: str = "validate", + mode: Optional[Union[str, callable]] = 'validate', # 首先尝试找 evaluate_step, 找不到 forward, callable input_mapping: Optional[Union[Callable, Dict]] = None, output_mapping: Optional[Union[Callable, Dict]] = None, model_wo_auto_param_call: bool = False, - fp16: Optional[bool] = False, + fp16: bool = False, verbose: int = 1, **kwargs ): @@ -92,8 +93,8 @@ class Evaluator: self.device = device self.verbose = verbose - assert check_fn_not_empty_params(batch_step_fn, 2), "Parameter `batch_step_fn` should be a callable object with " \ - "two parameters." + if batch_step_fn is not None: + _check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn') self.batch_step_fn = batch_step_fn self.mode = mode @@ -135,6 +136,7 @@ class Evaluator: if self.progress_bar == 'auto': self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw' + self.driver.check_evaluator_mode(self.mode) self.driver.barrier() def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: @@ -154,8 +156,6 @@ class Evaluator: assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0." 
- self.driver.check_evaluator_mode(self.mode) - if self.mode == 'validate': assert self.driver.has_validate_dataloaders() else: @@ -367,9 +367,10 @@ class _MetricsWrapper: raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly" f" return a dict from your model or use `output_mapping` to convert it into dict type.") if isinstance(metric, Metric): - auto_param_call(metric.update, outputs, *args) + # 这样在 auto_param_call 报错的时候才清晰。 + auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__) elif _is_torchmetrics_metric(metric): - auto_param_call(metric.update, outputs, *args) + auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__) elif _is_allennlp_metric(metric): auto_param_call(metric.__call__, outputs, *args) elif _is_paddle_metric(metric): diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index e1f31375..66e88827 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -14,6 +14,7 @@ __all__ = [ from .loops import Loop, TrainBatchLoop from .utils import State, TrainerState +from .utils.utils import check_validate_every from .evaluator import Evaluator from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList, Filter @@ -21,7 +22,8 @@ from fastNLP.core.callbacks.callback import _CallbackWrapper from fastNLP.core.callbacks.callback_events import _SingleEventState from fastNLP.core.drivers import Driver from fastNLP.core.drivers.utils import choose_driver -from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext +from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext +from fastNLP.core.utils.utils import _check_valid_parameters_number from fastNLP.envs import rank_zero_call from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_MODEL_FILENAME @@ -42,7 +44,7 @@ class Trainer(TrainerEventTrigger): validate_dataloaders=None, batch_step_fn: Optional[Callable] = None, validate_batch_step_fn: Optional[Callable] = None, - validate_mode: str = "validate", + validate_mode: Union[str, callable] = 'validate', callbacks: Union[List[Callback], Callback, None] = None, metrics: Optional[dict] = None, validate_every: Optional[Union[int, callable]] = -1, @@ -51,7 +53,7 @@ class Trainer(TrainerEventTrigger): model_wo_auto_param_call: bool = False, accumulation_steps: int = 1, fp16: bool = False, - monitor: str = None, + monitor: Union[str, callable] = None, larger_better: bool = True, marker: Optional[str] = None, **kwargs @@ -90,11 +92,8 @@ class Trainer(TrainerEventTrigger): :param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 `Callback` 类; :param metrics: 应当为一个字典,其中 key 表示 monitor,例如 {"acc1": AccMetric(), "acc2": AccMetric()}; :param validate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch validate 一次;为正数则表示每隔几个 batch validate 一次; - 为函数时表示用户自己传入的用于控制 Trainer 中的 validate 的频率的函数,该函数的参数应该为 (filter, trainer) , 其中的 filter 对象 - 中自动记录了两个变量: filter.num_called 表示有多少次尝试 validate (实际等同于到当前时刻 batch 的总数), filter.num_executed - 表示 validate 实际被执行了多少次;trainer 参数即为 Trainer 对象。 函数返回值应为 bool ,返回为 True 说明需要进行 validate 。 - 例如: (filter.num_called % trainer.num_batches_per_epoch == 0 and trainer.cur_epoch_idx > 10) 表示在第 10 个 epoch - 之后,每个 epoch 结束进行一次 validate 。 + 为函数时表示用户自己传入的用于控制 Trainer 中的 validate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 + 
返回一个 bool 值,返回为 True 说明需要进行 validate ;将在每个 batch 结束后调用该函数判断是否需要 validate 。 :param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理;如果其是 一个字典,并且 batch 也是一个 `Dict`,那么我们会把 batch 中同样在 input_mapping 中的 key 修改为 input_mapping 的对应 key 的 value;如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;如果 batch 此时是其它 @@ -111,7 +110,7 @@ class Trainer(TrainerEventTrigger): :param fp16: 是否开启混合精度训练;默认为 False; :param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 - 的那个作为 monitor 。 + 的那个作为 monitor 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 :param larger_better: monitor 的值是否是越大越好。 :param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; :param kwargs: 一些其它的可能需要的参数; @@ -142,10 +141,9 @@ class Trainer(TrainerEventTrigger): self.input_mapping = input_mapping self.output_mapping = output_mapping - assert check_fn_not_empty_params(batch_step_fn, 2), "`batch_step_fn` should be a callable object with " \ - "two parameters." self.batch_step_fn = batch_step_fn if batch_step_fn is not None: + _check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn') self.check_batch_step_fn = partial(self._check_callback_called_legality, check_mode=True) else: self.check_batch_step_fn = lambda *args, **kwargs: ... @@ -221,18 +219,11 @@ class Trainer(TrainerEventTrigger): if metrics is not None and validate_dataloaders is None: raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.") - # 为了在 train 的循环中每次都检查是否需要进行 validate,这里我们提前在 trainer 初始化的时候就将对应时间点需要运行的函数确定下来; - # _epoch_validate 表示每隔几个 epoch validate 一次;_step_validate 表示每隔几个 step validate 一次; self.evaluator = None self.monitor = monitor self.larger_better = larger_better if metrics is not None and validate_dataloaders is not None: - if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): - raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") - if callable(validate_every): - logger.info("Notice you are using a 'filter function' as the value of parameter `validate_every`, " - "and in this way, the kind of controlling frequency is depending on the 'step'.") - + check_validate_every(validate_every) self.evaluator = Evaluator( model=model, dataloaders=validate_dataloaders, @@ -352,33 +343,32 @@ class Trainer(TrainerEventTrigger): _validate_res: dict = validate_fn() trainer.on_validate_end(_validate_res) - self.validate_fn = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) + self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) def step_validate(self): - if self.evaluator is not None: - should_run_validate = False + """ + 在每个 batch 结束后调用,根据设置执行 evaluate 。 + :return: + """ + if self.evaluator is not None: if callable(self.validate_every): if self.validate_every(self): - should_run_validate = True - elif self.validate_every > 0: - if self.global_forward_batches % self.validate_every == 0: - should_run_validate = True - - if should_run_validate: - self.validate_fn() + self.run_evaluate() + elif self.validate_every > 0 and self.global_forward_batches % self.validate_every == 0: + self.run_evaluate() def epoch_validate(self): - if self.evaluator is not None: - should_run_validate = False + """ + 在每个 epoch 结束后调用,根据设置执行 evaluate 。 + :return: + """ + if 
self.evaluator is not None: if isinstance(self.validate_every, int) and self.validate_every < 0: validate_every = -self.validate_every if self.cur_epoch_idx % validate_every == 0: - should_run_validate = True - - if should_run_validate: - self.validate_fn() + self.run_evaluate() def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): r""" @@ -410,9 +400,7 @@ class Trainer(TrainerEventTrigger): def wrapper(fn: Callable) -> Callable: cls._custom_callbacks[marker].append((event, fn)) callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:] - assert check_fn_not_empty_params(fn, len(callback_fn_args)), \ - f"The callback function at `{event.value.lower()}`'s parameters should be {callback_fn_args}, but your "\ - f"function {fn.__name__} only has these parameters: {get_fn_arg_names(fn)}." + _check_valid_parameters_number(fn, callback_fn_args) return fn return wrapper diff --git a/fastNLP/core/controllers/utils/utils.py b/fastNLP/core/controllers/utils/utils.py index 0dce0b27..6e0824a1 100644 --- a/fastNLP/core/controllers/utils/utils.py +++ b/fastNLP/core/controllers/utils/utils.py @@ -1,8 +1,9 @@ -from collections.abc import Iterator +import inspect from typing import Dict from fastNLP.core.callbacks import CallbackManager from .state import TrainerState +from fastNLP.core.utils.utils import _check_valid_parameters_number class TrainerEventTrigger: @@ -125,5 +126,8 @@ class _TruncatedDataLoader: return getattr(self.dataloader, item) - - +def check_validate_every(validate_every): + if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): + raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") + if callable(validate_every): + _check_valid_parameters_number(validate_every, expected_params=['trainer']) diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index 5b8ec635..cd887253 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -178,10 +178,11 @@ class DataSet: elif isinstance(idx, slice): if idx.start is not None and (idx.start >= len(self) or idx.start <= -len(self)): raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self) - 1}") - data_set = DataSet() + dataset = DataSet() for field_name, field in self.field_arrays.items(): - data_set.add_field(field_name=field_name, fields=field.content[idx]) - return data_set + dataset.add_field(field_name=field_name, fields=field.content[idx]) + dataset.collate_fns = deepcopy(self.collate_fns) + return dataset elif isinstance(idx, str): if idx not in self: raise KeyError("No such field called {} in DataSet.".format(idx)) @@ -192,6 +193,7 @@ class DataSet: assert isinstance(i, int), "Only int index allowed." 
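check_validate_every above delegates signature checking to _check_valid_parameters_number, which replaces the old check_fn_not_empty_params throughout this patch. An illustrative reimplementation built on inspect (fastNLP's real helper lives in fastNLP/core/utils/utils.py and produces richer error messages):

import inspect

def _check_valid_parameters_number(fn, expected_params, fn_name=None):
    """Raise if fn cannot be called with len(expected_params) positional arguments."""
    try:
        inspect.signature(fn).bind(*expected_params)
    except TypeError as e:
        name = fn_name or getattr(fn, "__name__", repr(fn))
        raise RuntimeError(f"The function {name!r} must be callable with parameters "
                           f"{expected_params}.") from e

def validate_every(trainer):                    # matches expected_params=['trainer']
    return trainer.global_forward_batches % 10 == 0

_check_valid_parameters_number(validate_every, ['trainer'])   # passes silently
try:
    _check_valid_parameters_number(lambda: True, ['trainer'])
except RuntimeError as e:
    print(e)                                    # the zero-argument lambda is rejected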
instance = self[i] dataset.append(instance) + dataset.collate_fns = deepcopy(self.collate_fns) return dataset else: raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) @@ -674,6 +676,8 @@ class DataSet: dev_set.append(self[idx]) for idx in train_indices: train_set.append(self[idx]) + dev_set.collate_fns = deepcopy(self.collate_fns) + train_set.collate_fns = deepcopy(self.collate_fns) return dev_set, train_set @@ -795,7 +799,7 @@ class DataSet: :param val: 默认为0。如果为 None ,则为不对 field 进行 padding 。 :return: """ - # TODO 需要去重复 + # TODO 不能为空 for field_name in field_names: self.collate_fns.set_pad_val(field_name, val=val) diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py index a8ad32e8..411fdf69 100644 --- a/fastNLP/core/drivers/jittor_driver/jittor_driver.py +++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py @@ -66,7 +66,7 @@ class JittorDriver(Driver): if mode == "validate": if not hasattr(model, "validate_step"): if hasattr(model, "test_step"): - logger.warning( + logger.warning_once( "Your model does not have 'validate_step' method but has 'test_step' method, but you" "are using 'mode=validate', we are going to use 'test_step' to substitute for" "'validate_step'.") @@ -74,7 +74,7 @@ class JittorDriver(Driver): else: if not hasattr(model, "test_step"): if hasattr(model, "validate_step"): - logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you" + logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" "are using 'mode=test', we are going to use 'validate_step' to substitute for" "'test_step'.") diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 4362dcce..931921fd 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -133,7 +133,7 @@ class PaddleDriver(Driver): else: if not hasattr(model, "test_step"): if hasattr(model, "validate_step"): - logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you" + logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" "are using 'Evaluator.test', we are going to use 'validate_step' to substitute for" "'test_step'.") diff --git a/fastNLP/core/drivers/torch_driver/dist_utils.py b/fastNLP/core/drivers/torch_driver/dist_utils.py index ad9e6794..37110577 100644 --- a/fastNLP/core/drivers/torch_driver/dist_utils.py +++ b/fastNLP/core/drivers/torch_driver/dist_utils.py @@ -333,10 +333,8 @@ def all_gather_object(object_list, obj, group=None): >>> output ['foo', 12, {1: 2}] """ - if dist._rank_not_in_group(group): + if dist.distributed_c10d._rank_not_in_group(group): return - - input_tensor, local_size = _object_to_tensor(obj) if _TORCH_GREATER_EQUAL_1_8: current_device = torch.device("cpu") is_nccl_backend = _check_for_nccl_backend(group) @@ -345,10 +343,11 @@ def all_gather_object(object_list, obj, group=None): # We cannot simply use my_rank since rank == device is not necessarily # true. current_device = torch.device("cuda", torch.cuda.current_device()) - input_tensor = input_tensor.to(current_device) - local_size = local_size.to(current_device) else: current_device = torch.cuda.current_device() + + input_tensor, local_size = _object_to_tensor(obj, device=current_device) + # Gather all local sizes. 
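warning_once, used in the jittor and paddle driver changes above, is what makes it safe to warn inside per-batch code paths such as zero_grad: the message is emitted the first time and swallowed afterwards. A minimal sketch of the deduplication idea (FastNLPLogger keeps the seen-set on the logger instance itself):

import logging

_seen_warnings = set()

def warning_once(logger, msg):
    """Emit msg at WARNING level only the first time this exact text is seen."""
    if msg not in _seen_warnings:
        _seen_warnings.add(msg)
        logger.warning(msg)

logging.basicConfig(level=logging.WARNING)
log = logging.getLogger("demo")
for _ in range(3):                       # logged once despite three calls
    warning_once(log, "Parameter `set_to_none` does nothing in paddle ...")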
This is so that we can find the max size, and index # until the correct size when deserializing the tensors. group_size = dist.get_world_size(group=group) @@ -379,3 +378,4 @@ def all_gather_object(object_list, obj, group=None): tensor = tensor.cpu() tensor_size = object_size_list[i] object_list[i] = _tensor_to_object(tensor, tensor_size) + return object_list \ No newline at end of file diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index c60d1552..f1e33d5e 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -113,7 +113,7 @@ class TorchDriver(Driver): if mode == "validate": if not hasattr(model, "validate_step"): if hasattr(model, "test_step"): - logger.warning( + logger.warning_once( "Your model does not have 'validate_step' method but has 'test_step' method, but you" "are using 'mode=validate', we are going to use 'test_step' to substitute for" "'validate_step'.") diff --git a/fastNLP/core/log/logger.py b/fastNLP/core/log/logger.py index 9763ab4a..004bfb16 100644 --- a/fastNLP/core/log/logger.py +++ b/fastNLP/core/log/logger.py @@ -125,9 +125,9 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): self._warning_msgs.add(msg) def warn(self, msg, *args, **kwargs): - warnings.warn("The 'warn' method is deprecated, " - "use 'warning' instead", DeprecationWarning, 2) - self.warning(msg, *args, **kwargs) + if self.isEnabledFor(WARNING): + kwargs = self._add_rank_info(kwargs) + self._log(WARNING, msg, args, **kwargs) def error(self, msg, *args, **kwargs): """ diff --git a/fastNLP/core/metrics/accuracy.py b/fastNLP/core/metrics/accuracy.py index 0a60e4d7..d1ac1776 100644 --- a/fastNLP/core/metrics/accuracy.py +++ b/fastNLP/core/metrics/accuracy.py @@ -14,8 +14,7 @@ from fastNLP.core.utils.utils import seq_len_to_mask class Accuracy(Metric): - def __init__(self, backend: Union[str, Backend, None] = 'auto', - aggregate_when_get_metric: bool = True): + def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True): super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend) self.register_element(name='total', value=0, aggregate_method="sum", backend=backend) @@ -64,7 +63,7 @@ class Accuracy(Metric): warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") else: - raise RuntimeError(f"when pred havesize:{pred.shape}, target should have size: {pred.shape} or " + raise RuntimeError(f"when pred have size:{pred.shape}, target should have size: {pred.shape} or " f"{pred.shape[:-1]}, got {target.shape}.") if masks is not None: diff --git a/fastNLP/core/samplers/__init__.py b/fastNLP/core/samplers/__init__.py index c3cc2d39..61433e8e 100644 --- a/fastNLP/core/samplers/__init__.py +++ b/fastNLP/core/samplers/__init__.py @@ -23,14 +23,14 @@ __all__ = [ "BucketedBatchSampler", "ReproducibleBatchSampler", - "re_instantiate_sampler", - "conversion_between_reproducible_and_unrepeated_sampler" + "re_instantiate_sampler" ] from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler from .reproducible_sampler import 
ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler -from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler +from .utils import re_instantiate_sampler +from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler diff --git a/fastNLP/core/samplers/conversion_utils.py b/fastNLP/core/samplers/conversion_utils.py new file mode 100644 index 00000000..d5d97d0c --- /dev/null +++ b/fastNLP/core/samplers/conversion_utils.py @@ -0,0 +1,33 @@ +from fastNLP.core.samplers import re_instantiate_sampler +from fastNLP.core.samplers.reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, \ + SortedSampler +from fastNLP.core.samplers.unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, \ + UnrepeatedSequentialSampler, UnrepeatedSortedSampler + + +def conversion_between_reproducible_and_unrepeated_sampler(sampler): + """ + 将 sampler 替换成其对应的 reproducible 版本或 unrepeated 版本。如果输入是 UnrepeatedSampler 但是没找到对应的 + ReproducibleSampler, + + :param sampler: + :return: + """ + assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \ + "The sampler must be UnrepeatedSampler or ReproducibleSampler" + if isinstance(sampler, UnrepeatedSampler): + if isinstance(sampler, UnrepeatedRandomSampler): + return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler) + elif isinstance(sampler, UnrepeatedSequentialSampler): + return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler) + elif isinstance(sampler, UnrepeatedSortedSampler): + return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler) + raise TypeError(f"{sampler.__class__} has no unrepeated version.") + else: + if isinstance(sampler, RandomSampler): + return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler) + elif isinstance(sampler, SequentialSampler): + return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler) + elif isinstance(sampler, SortedSampler): + return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler) + raise TypeError(f"{sampler.__class__} has no reproducible version.") \ No newline at end of file diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index d4535bae..1d2c96d9 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -378,7 +378,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): batch_indices = list(batch_indices[:-1]) rng = np.random.default_rng(abs(seed)) # 这里防止由于bucket长度不同,对随机数状态有影响 rng.shuffle(batch_indices) # 不同的 batch 也 shuffle ,当前这种可以保证每张卡上每个 batch 长度都接近的。 - batches = (np.array(batches)[batch_indices]).tolist() + batches = (np.array(batches, dtype=object)[batch_indices]).tolist() if last_batches: batches = batches + last_batches return batches @@ -387,19 +387,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler): if self.old_batch_size != self.batch_size or self.old_num_batch_per_bucket != self.num_batch_per_bucket: raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" " consumed. 
") - states = { - 'seed': self.seed, - 'epoch': self.epoch, - 'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; - 'sampler_type': self.__class__.__name__, - 'length': len(self.dataset), - 'shuffle': self.shuffle, - 'batch_size': self.batch_size, - 'num_batch_per_bucket': self.num_batch_per_bucket, - 'num_replicas': self.num_replicas - } + states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, + 'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, + 'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket, + 'num_replicas': self.num_replicas, + 'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} - states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None) return states def load_state_dict(self, states: Dict): diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index 396e69b2..6ea9cc6b 100644 --- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -1,3 +1,10 @@ +__all__ = [ + 'ReproducibleSampler', + 'RandomSampler', + "SortedSampler", + "SequentialSampler" +] + from typing import Dict, List, Union import math import os @@ -10,13 +17,6 @@ from fastNLP.envs.env import FASTNLP_DEQUE_SIZE from .utils import NumConsumedSamplesArray -__all__ = [ - 'ReproducibleSampler', - 'RandomSampler', - "SortedSampler", - "SequentialSampler" -] - class ReproducibleSampler: """ diff --git a/fastNLP/core/samplers/utils.py b/fastNLP/core/samplers/utils.py index 80af1787..ddcff37f 100644 --- a/fastNLP/core/samplers/utils.py +++ b/fastNLP/core/samplers/utils.py @@ -1,42 +1,10 @@ __all__ = [ - 're_instantiate_sampler', - 'conversion_between_reproducible_and_unrepeated_sampler' + 're_instantiate_sampler' ] from array import array from typing import Sequence from collections import deque -from fastNLP.core.samplers.unrepeated_sampler import * -from fastNLP.core.samplers.reproducible_sampler import * - - -def conversion_between_reproducible_and_unrepeated_sampler(sampler): - """ - 将 sampler 替换成其对应的 reproducible 版本或 unrepeated 版本。如果输入是 UnrepeatedSampler 但是没找到对应的 - ReproducibleSampler, - - :param sampler: - :return: - """ - assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \ - "The sampler must be UnrepeatedSampler or ReproducibleSampler" - if isinstance(sampler, UnrepeatedSampler): - if isinstance(sampler, UnrepeatedRandomSampler): - return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler) - elif isinstance(sampler, UnrepeatedSequentialSampler): - return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler) - elif isinstance(sampler, UnrepeatedSortedSampler): - return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler) - raise TypeError(f"{sampler.__class__} has no unrepeated version.") - else: - if isinstance(sampler, RandomSampler): - return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler) - elif isinstance(sampler, SequentialSampler): - return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler) - elif isinstance(sampler, SortedSampler): - return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler) - raise TypeError(f"{sampler.__class__} has no reproducible version.") - def re_instantiate_sampler(sampler, new_sampler_class=None): all_attributes = vars(sampler) 
diff --git a/fastNLP/core/utils/__init__.py b/fastNLP/core/utils/__init__.py index 1d1c9d16..cceb948f 100644 --- a/fastNLP/core/utils/__init__.py +++ b/fastNLP/core/utils/__init__.py @@ -13,7 +13,6 @@ __all__ = [ 'torch_paddle_move_data_to_device', 'torch_move_data_to_device', 'get_fn_arg_names', - 'check_fn_not_empty_params', 'auto_param_call', 'check_user_specific_params', 'dataclass_to_dict', @@ -36,7 +35,7 @@ from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_devi from .rich_progress import f_rich_progress from .torch_paddle_utils import torch_paddle_move_data_to_device from .torch_utils import torch_move_data_to_device -from .utils import get_fn_arg_names, check_fn_not_empty_params, auto_param_call, check_user_specific_params, \ +from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ indice_collate_wrapper, deprecated, seq_len_to_mask, synchronize_safe_rm, synchronize_mkdir diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index d593f4ee..7af6557f 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -1,3 +1,4 @@ +import functools import inspect from inspect import Parameter import dataclasses @@ -24,10 +25,8 @@ from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_GLOBAL_RANK - __all__ = [ 'get_fn_arg_names', - 'check_fn_not_empty_params', 'auto_param_call', 'check_user_specific_params', 'dataclass_to_dict', @@ -54,30 +53,6 @@ def get_fn_arg_names(fn: Callable) -> List[str]: return list(inspect.signature(fn).parameters) -def check_fn_not_empty_params(fn: Optional[Callable] = None, param_num: Optional[int] = None) -> bool: - r""" - 检查传入的batch_step_fn是否是合法的:(1) 是否是 callable 的; (2) 没有默认值的参数是否只有指定个数; - 用户也可以传进一个 partial 的函数进来,只要其保证留有 `trainer` 和 `batch` 的参数位置即可; - - :param fn: 传入的用以代替 Loop 中 'step' 函数的函数; - :param param_num: 检测的函数的应当的没有默认值的参数的个数; - - :return: bool,表示传入的 `batch_step_fn` 是否正确; - """ - - if fn is None: - return True - if not callable(fn): - return False - else: - params = inspect.signature(fn).parameters - not_default_params = {} - for _name, _param in params.items(): - if _param.default == Parameter.empty: - not_default_params[_name] = _param - return len(not_default_params) == param_num - - def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None, mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any: r""" @@ -95,7 +70,6 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None :param signature_fn: 函数,用来替换 `fn` 的函数签名,如果该参数不为 None,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取 参数值后,再传给 `fn` 进行实际的运算; :param mapping: 一个字典,用来更改其前面的字典的键值; - :param wo_auto_param_call: 是否关闭默认的参数匹配行为; :return: 返回 `fn` 运行的结果; @@ -123,7 +97,8 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None _kwargs = None for _name, _param in _need_params.items(): if _param.kind == Parameter.VAR_POSITIONAL: - raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn.__name__}.") + fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) + raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn_msg}.") if _param.kind == Parameter.VAR_KEYWORD: _kwargs = (_name, _param) @@ -136,12 +111,17 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None _default_params[_name] = _param.default if mapping is not None: - assert isinstance(mapping, 
Dict), f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}." + fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) + assert isinstance(mapping, Dict), f"Exception happens when calling {fn_msg}. " \ + f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}." _has_params = {} duplicate_names = [] for arg in args: - assert isinstance(arg, Dict), "The input part of function `auto_param_call` can only be `Dict` type." + if not isinstance(arg, Dict): + fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) + raise TypeError(f"Exception happens when calling {fn_msg}. " + f"The input part of function `auto_param_call` must be `Dict` type, instead of {type(arg)}.") for _name, _value in arg.items(): if mapping is not None and _name in mapping: _name = mapping[_name] @@ -153,7 +133,8 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None elif _name in _need_params and not (_has_params[_name] is _value): duplicate_names.append(_name) if duplicate_names: - raise ValueError(f"The following key present in several inputs:{duplicate_names}") + fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) + raise ValueError(f"The following key present in several inputs:{duplicate_names} when calling {fn_msg}.") # 将具有默认值但是没有被输入修改过的参数值传进去; for _name, _value in _default_params.items(): @@ -162,11 +143,89 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None if len(_has_params) List[List[str]]: + """ + 返回每个 dict 的 keys + + :param args: + :return: + """ + _provided_keys = [] + for arg in args: + _provided_keys.append(list(arg.keys())) + return _provided_keys + + +def _get_fun_msg(fn)->str: + """ + 获取函数的基本信息,帮助报错。 + ex: + print(_get_fun_msg(_get_fun_msg)) + # `_get_fun_msg(fn) -> str`(In file:/Users/hnyan/Desktop/projects/fastNLP/fastNLP/fastNLP/core/utils/utils.py) + + :param callable fn: + :return: + """ + if isinstance(fn, functools.partial): + return _get_fun_msg(fn.func) + try: + fn_name = fn.__qualname__ + str(inspect.signature(fn)) + except: + fn_name = str(fn) + try: + fp = '(In file:' + os.path.abspath(inspect.getfile(fn)) + ')' + except: + fp = '' + msg = f'`{fn_name}`' + fp + return msg + + +def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None): + """ + 检查一个函数是否需要 expected_params 参数(检测数量是否匹配)。除掉 self (如果是method),给定默认值的参数等。如果匹配不上,就会 + 进行报错。 + + :param fn: 需要检测的函数,可以是 method 或者 function 。 + :param expected_params: 期待应该支持的参数。 + :param fn_name: fn 的名字,当传入的 fn 不是 callable 的时候方便报错。 + :return: + """ + if fn_name is not None: + assert callable(fn), f"{fn_name} should be callable, instead of {type(fn)}." 
+
+    parameters = list(inspect.signature(fn).parameters.values())
+    if inspect.ismethod(fn):
+        if len(parameters)>0 and parameters[0].name == 'self':
+            parameters = parameters[1:]  # 去掉self
+
+    no_var_param = True  # 没有 * 这种参数
+    number_param_need_value = 0
+    for param in parameters:
+        if param.kind is param.VAR_POSITIONAL:
+            no_var_param = False
+        elif param.kind is param.VAR_KEYWORD:
+            no_var_param = False
+        else:
+            if param.default is param.empty:
+                number_param_need_value += 1
+
+    if len(parameters)<len(expected_params) and no_var_param:
+        raise RuntimeError(f"The function:{_get_fun_msg(fn)} accepts {len(parameters)} parameters, "
+                           f"but {len(expected_params)} parameters:{expected_params} will be provided.")
+
+    if number_param_need_value>len(expected_params):
+        raise RuntimeError(f"The function:{_get_fun_msg(fn)} expects {len(parameters)} parameters, but only"
+                           f" {len(expected_params)} parameters:{expected_params} will be provided.")
+
+
 def check_user_specific_params(user_params: Dict, fn: Callable):
     """
     该函数使用用户的输入来对指定函数的参数进行赋值;
@@ -592,4 +651,24 @@ def synchronize_mkdir(path: Optional[Union[str, Path]]):
     wait_to_success(path.exists)
 
 
+def get_class_that_defined_method(method):
+    """
+    给定一个method,返回这个 method 的 class 的对象
+    :param method:
+    :return:
+    """
+    if isinstance(method, functools.partial):
+        return get_class_that_defined_method(method.func)
+    if inspect.ismethod(method) or (inspect.isbuiltin(method) and getattr(method, '__self__', None) is not None and getattr(method.__self__, '__class__', None)):
+        for cls in inspect.getmro(method.__self__.__class__):
+            if method.__name__ in cls.__dict__:
+                return cls
+        method = getattr(method, '__func__', method)  # fallback to __qualname__ parsing
+    if inspect.isfunction(method):
+        cls = getattr(inspect.getmodule(method),
+                      method.__qualname__.split('.', 1)[0].rsplit('.', 1)[0],
+                      None)
+        if isinstance(cls, type):
+            return cls
+    return getattr(method, '__objclass__', None)  # handle special descriptor objects
\ No newline at end of file
diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py
index 5daee519..2796bb69 100644
--- a/fastNLP/io/data_bundle.py
+++ b/fastNLP/io/data_bundle.py
@@ -251,10 +251,10 @@ class DataBundle:
     def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True,
                          ignore_miss_dataset=True, progress_desc: str = '', show_progress_bar: bool = True):
         r"""
-        对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法
+        对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply_field_more` 方法
 
         .. note::
-            ``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与
+            ``apply_field_more`` 与 ``apply_field`` 的区别参考 :method:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与
             ``apply`` 区别的介绍。
 
         :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
@@ -285,7 +285,7 @@ class DataBundle:
     def apply(self, func: Callable, new_field_name: str, num_proc: int = 0,
              progress_desc: str = '', show_progress_bar: bool = True, _apply_field: str = None):
         r"""
-        对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply` 方法
+        对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply` 方法
 
         对DataBundle中所有的dataset使用apply方法
 
@@ -309,10 +309,10 @@ class DataBundle:
     def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0,
                    progress_desc: str = '', show_progress_bar: bool = True):
         r"""
-        对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_more` 方法
+        对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply_more` 方法
 
         .. 
note:: - ``apply_more`` 与 ``apply`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 + ``apply_more`` 与 ``apply`` 的区别参考 :method:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 ``apply`` 区别的介绍。 :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py index 0e5915a9..b5db4bd6 100644 --- a/fastNLP/io/pipe/classification.py +++ b/fastNLP/io/pipe/classification.py @@ -87,7 +87,7 @@ class CLSBasePipe(Pipe): def process_from_file(self, paths) -> DataBundle: r""" - 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` + 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::method:`fastNLP.io.Loader.load()` :param paths: :return: DataBundle diff --git a/fastNLP/io/pipe/construct_graph.py b/fastNLP/io/pipe/construct_graph.py index 1448765e..1b6d192a 100644 --- a/fastNLP/io/pipe/construct_graph.py +++ b/fastNLP/io/pipe/construct_graph.py @@ -164,7 +164,7 @@ class GraphBuilderBase: def build_graph_from_file(self, path: str): r""" - 传入文件路径,生成处理好的scipy_sparse_matrix对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` + 传入文件路径,生成处理好的scipy_sparse_matrix对象。paths支持的路径形式可以参考 ::method:`fastNLP.io.Loader.load()` :param path: :return: scipy_sparse_matrix diff --git a/fastNLP/io/pipe/pipe.py b/fastNLP/io/pipe/pipe.py index 4916bf09..f974b548 100644 --- a/fastNLP/io/pipe/pipe.py +++ b/fastNLP/io/pipe/pipe.py @@ -33,7 +33,7 @@ class Pipe: def process_from_file(self, paths: str) -> DataBundle: r""" - 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` + 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::method:`fastNLP.io.Loader.load()` :param str paths: :return: DataBundle diff --git a/tests/core/utils/test_utils.py b/tests/core/utils/test_utils.py new file mode 100644 index 00000000..a7aeffb1 --- /dev/null +++ b/tests/core/utils/test_utils.py @@ -0,0 +1,187 @@ +from functools import partial + +import pytest + +from fastNLP.core.utils.utils import auto_param_call, _check_valid_parameters_number, _get_fun_msg +from fastNLP.core.metrics import Metric + + + +class TestAutoParamCall: + def test_basic(self): + def fn(x): + return x + x = {'x': 3, 'y': 4} + r = auto_param_call(fn, x) + assert r==3 + + xs = [] + for i in range(10): + xs.append({f'x{i}': i}) + def fn(x0, x1, x2, x3): + return x0 + x1 + x2 + x3 + r = auto_param_call(fn, *xs) + assert r == 0 + 1+ 2+ 3 + + def fn(chongfu1, chongfu2, buChongFu): + pass + with pytest.raises(BaseException) as exc_info: + auto_param_call(fn, {'chongfu1': 3, "chongfu2":4, 'buChongFu':2}, + {'chongfu1': 1, 'chongfu2':2, 'buChongFu':2}) + assert 'The following key present in several inputs' in exc_info.value.args[0] + assert 'chongfu1' in exc_info.value.args[0] and 'chongfu2' in exc_info.value.args[0] + + # 没用到不报错 + def fn(chongfu1, buChongFu): + pass + auto_param_call(fn, {'chongfu1': 1, "chongfu2":4, 'buChongFu':2}, + {'chongfu1': 1, 'chongfu2':2, 'buChongFu':2}) + + # 可以定制signature_fn + def fn1(**kwargs): + kwargs.pop('x') + kwargs.pop('y') + assert len(kwargs)==0 + def fn(x, y): + pass + x = {'x': 3, 'y': 4} + r = auto_param_call(fn1, x, signature_fn=fn) + + # 没提供的时候报错 + def fn(meiti1, meiti2, tigong): + pass + with pytest.raises(BaseException) as exc_info: + auto_param_call(fn, {'tigong':1}) + assert 'meiti1' in exc_info.value.args[0] and 'meiti2' in exc_info.value.args[0] + + # 默认值替换 + def fn(x, y=100): + return x + y + r = auto_param_call(fn, {'x': 10, 'y': 20}) + assert r==30 + assert auto_param_call(fn, {'x': 
10, 'z': 20})==110 + + # 测试mapping的使用 + def fn(x, y=100): + return x + y + r = auto_param_call(fn, {'x1': 10, 'y1': 20}, mapping={'x1': 'x', 'y1': 'y', 'meiyong': 'meiyong'}) + assert r==30 + + # 测试不需要任何参数 + def fn(): + return 1 + assert 1 == auto_param_call(fn, {'x':1}) + + # 测试调用类的方法没问题 + assert 2==auto_param_call(self.call_this, {'x':1 ,'y':1}) + assert 2==auto_param_call(self.call_this, {'x':1,'y':1, 'z':1},mapping={'z': 'self'}) + + def test_msg(self): + with pytest.raises(BaseException) as exc_info: + auto_param_call(self.call_this, {'x':1}) + assert 'TestAutoParamCall.call_this' in exc_info.value.args[0] + + with pytest.raises(BaseException) as exc_info: + auto_param_call(call_this_for_auto_param_call, {'x':1}) + assert __file__ in exc_info.value.args[0] + assert 'call_this_for_auto_param_call' in exc_info.value.args[0] + + with pytest.raises(BaseException) as exc_info: + auto_param_call(self.call_this_two, {'x':1}) + assert __file__ in exc_info.value.args[0] + + with pytest.raises(BaseException) as exc_info: + auto_param_call(call_this_for_auto_param_call, {'x':1}, signature_fn=self.call_this) + assert 'TestAutoParamCall.call_this' in exc_info.value.args[0] # 应该是signature的信息 + + def call_this(self, x, y): + return x + y + + def call_this_two(self, x, y, z=pytest, **kwargs): + return x + y + + def test_metric_auto_param_call(self): + metric = AutoParamCallMetric() + with pytest.raises(BaseException): + auto_param_call(metric.update, {'y':1}, signature_fn=metric.update.__wrapped__) + + +class AutoParamCallMetric(Metric): + def update(self, x): + pass + + +def call_this_for_auto_param_call(x, y): + return x + y + + +class TestCheckNumberOfParameters: + def test_validate_every(self): + def validate_every(trainer): + pass + _check_valid_parameters_number(validate_every, expected_params=['trainer']) + + # 无默认值,多了报错 + def validate_every(trainer, other): + pass + with pytest.raises(RuntimeError) as exc_info: + _check_valid_parameters_number(validate_every, expected_params=['trainer']) + assert "2 parameters" in exc_info.value.args[0] + print(exc_info.value.args[0]) + + # 有默认值ok + def validate_every(trainer, other=1): + pass + _check_valid_parameters_number(validate_every, expected_params=['trainer']) + + # 参数多了 + def validate_every(trainer): + pass + with pytest.raises(RuntimeError) as exc_info: + _check_valid_parameters_number(validate_every, expected_params=['trainer', 'other']) + assert "accepts 1 parameters" in exc_info.value.args[0] + print(exc_info.value.args[0]) + + # 使用partial + def validate_every(trainer, other): + pass + _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer']) + _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other']) + with pytest.raises(RuntimeError) as exc_info: + _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other', 'more']) + assert 'accepts 2 parameters' in exc_info.value.args[0] + print(exc_info.value.args[0]) + + # 如果存在 *args 或 *kwargs 不报错多的 + def validate_every(trainer, *args): + pass + _check_valid_parameters_number(validate_every, expected_params=['trainer', 'other', 'more']) + + def validate_every(trainer, **kwargs): + pass + _check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more']) + + # class 的方法删掉self + class InnerClass: + def demo(self, x): + pass + + def no_param(self): + pass + + def param_kwargs(self, **kwargs): + pass + + inner = InnerClass() + with 
pytest.raises(RuntimeError) as exc_info: + _check_valid_parameters_number(inner.demo, expected_params=['trainer', 'other', 'more']) + assert 'accepts 1 parameters' in exc_info.value.args[0] + + _check_valid_parameters_number(inner.demo, expected_params=['trainer']) + + +def test_get_fun_msg(): + def demo(x): + pass + + print(_get_fun_msg(_get_fun_msg)) \ No newline at end of file diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index b876c289..9a4af07c 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -2,37 +2,19 @@ import os import sys import __main__ from functools import wraps -import inspect from inspect import ismethod -import functools from copy import deepcopy from io import StringIO import time import numpy as np +from fastNLP.core.utils.utils import get_class_that_defined_method from fastNLP.envs.env import FASTNLP_GLOBAL_RANK from fastNLP.core.drivers.utils import distributed_open_proc from fastNLP.core.log import logger -def get_class_that_defined_method(meth): - if isinstance(meth, functools.partial): - return get_class_that_defined_method(meth.func) - if inspect.ismethod(meth) or (inspect.isbuiltin(meth) and getattr(meth, '__self__', None) is not None and getattr(meth.__self__, '__class__', None)): - for cls in inspect.getmro(meth.__self__.__class__): - if meth.__name__ in cls.__dict__: - return cls - meth = getattr(meth, '__func__', meth) # fallback to __qualname__ parsing - if inspect.isfunction(meth): - cls = getattr(inspect.getmodule(meth), - meth.__qualname__.split('.', 1)[0].rsplit('.', 1)[0], - None) - if isinstance(cls, type): - return cls - return getattr(meth, '__objclass__', None) # handle special descriptor objects - - def recover_logger(fn): @wraps(fn) def wrapper(*args, **kwargs):
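# ---------------------------------------------------------------------------
# Two sketches for the helpers exercised above (illustrative, not part of the
# patch; names not shown in the diffs are assumptions).
#
# Why `partial` "absorbs" a required parameter in TestCheckNumberOfParameters:
# binding a keyword gives that parameter a default in the visible signature.
import inspect
from functools import partial

def validate_every(trainer, other):
    pass

print(inspect.signature(partial(validate_every, other=1)))
# -> (trainer, *, other=1): only 'trainer' still needs a value, so
#    expected_params=['trainer'] passes while ['trainer', 'other', 'more'] fails.

# And `get_class_that_defined_method`, now imported from fastNLP.core.utils:
class Demo:
    def run(self):
        pass

assert get_class_that_defined_method(Demo().run) is Demo    # bound method -> MRO lookup
assert get_class_that_defined_method(Demo.run) is Demo      # plain function -> __qualname__

# The truncated `recover_logger` wrapper above presumably snapshots the global
# logger's handlers and level before the test and restores them afterwards;
# its full body is not shown in this excerpt.
# ---------------------------------------------------------------------------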