From 3310a6cadfadda3b82f6a23c776ae66504a099ae Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 2 May 2022 07:03:46 +0000 Subject: [PATCH 1/6] =?UTF-8?q?=E9=83=A8=E5=88=86=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E4=BE=8B=E9=87=8D=E5=91=BD=E5=90=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../{test_torch_paddle_driver.py => _test_torch_paddle_driver.py} | 0 .../drivers/torch_paddle_driver/{test_utils.py => _test_utils.py} | 0 .../metrics/{test_accutacy_paddle.py => test_accuracy_paddle.py} | 0 .../{test_torch_paddle_utils.py => _test_torch_paddle_utils.py} | 0 .../mix_modules/{test_mix_module.py => _test_mix_module.py} | 0 5 files changed, 0 insertions(+), 0 deletions(-) rename tests/core/drivers/torch_paddle_driver/{test_torch_paddle_driver.py => _test_torch_paddle_driver.py} (100%) rename tests/core/drivers/torch_paddle_driver/{test_utils.py => _test_utils.py} (100%) rename tests/core/metrics/{test_accutacy_paddle.py => test_accuracy_paddle.py} (100%) rename tests/core/utils/{test_torch_paddle_utils.py => _test_torch_paddle_utils.py} (100%) rename tests/modules/mix_modules/{test_mix_module.py => _test_mix_module.py} (100%) diff --git a/tests/core/drivers/torch_paddle_driver/test_torch_paddle_driver.py b/tests/core/drivers/torch_paddle_driver/_test_torch_paddle_driver.py similarity index 100% rename from tests/core/drivers/torch_paddle_driver/test_torch_paddle_driver.py rename to tests/core/drivers/torch_paddle_driver/_test_torch_paddle_driver.py diff --git a/tests/core/drivers/torch_paddle_driver/test_utils.py b/tests/core/drivers/torch_paddle_driver/_test_utils.py similarity index 100% rename from tests/core/drivers/torch_paddle_driver/test_utils.py rename to tests/core/drivers/torch_paddle_driver/_test_utils.py diff --git a/tests/core/metrics/test_accutacy_paddle.py b/tests/core/metrics/test_accuracy_paddle.py similarity index 100% rename from tests/core/metrics/test_accutacy_paddle.py rename to tests/core/metrics/test_accuracy_paddle.py diff --git a/tests/core/utils/test_torch_paddle_utils.py b/tests/core/utils/_test_torch_paddle_utils.py similarity index 100% rename from tests/core/utils/test_torch_paddle_utils.py rename to tests/core/utils/_test_torch_paddle_utils.py diff --git a/tests/modules/mix_modules/test_mix_module.py b/tests/modules/mix_modules/_test_mix_module.py similarity index 100% rename from tests/modules/mix_modules/test_mix_module.py rename to tests/modules/mix_modules/_test_mix_module.py From f2fa05d1360fe02c1c996a0c95d1bb9e5260cd3a Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 2 May 2022 07:16:09 +0000 Subject: [PATCH 2/6] =?UTF-8?q?=E6=B5=8B=E8=AF=95=E4=BE=8B=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=20=5FNEED=5FIMPORT=5FPADDLE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/dataloaders/paddle_dataloader/test_fdl.py | 8 +++++--- tests/core/drivers/paddle_driver/test_dist_utils.py | 7 ++++--- tests/core/drivers/paddle_driver/test_fleet.py | 9 +++++---- .../paddle_driver/test_initialize_paddle_driver.py | 5 +++-- tests/core/drivers/paddle_driver/test_single_device.py | 10 ++++++---- tests/core/drivers/paddle_driver/test_utils.py | 7 ++++--- tests/core/drivers/torch_driver/test_single_device.py | 10 ++++++---- tests/core/metrics/test_accuracy_paddle.py | 10 ++++++---- tests/core/utils/test_paddle_utils.py | 4 +++- tests/helpers/datasets/paddle_data.py | 7 +++++-- tests/helpers/models/paddle_model.py | 6 ++++-- 11 files changed, 51 insertions(+), 32 deletions(-) diff --git a/tests/core/dataloaders/paddle_dataloader/test_fdl.py b/tests/core/dataloaders/paddle_dataloader/test_fdl.py index 83e40610..484b0daa 100644 --- a/tests/core/dataloaders/paddle_dataloader/test_fdl.py +++ b/tests/core/dataloaders/paddle_dataloader/test_fdl.py @@ -1,10 +1,12 @@ import pytest +import numpy as np from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader from fastNLP.core.dataset import DataSet -from paddle.io import Dataset, DataLoader -import numpy as np -import paddle +from fastNLP.envs.imports import _NEED_IMPORT_PADDLE +if _NEED_IMPORT_PADDLE: + from paddle.io import Dataset, DataLoader + import paddle class RandomDataset(Dataset): diff --git a/tests/core/drivers/paddle_driver/test_dist_utils.py b/tests/core/drivers/paddle_driver/test_dist_utils.py index 8b136b3c..da40ad78 100644 --- a/tests/core/drivers/paddle_driver/test_dist_utils.py +++ b/tests/core/drivers/paddle_driver/test_dist_utils.py @@ -14,9 +14,10 @@ from fastNLP.core.drivers.paddle_driver.dist_utils import ( ) from fastNLP.core.drivers.paddle_driver.fleet_launcher import FleetLauncher from tests.helpers.utils import magic_argv_env_context - -import paddle -import paddle.distributed as dist +from fastNLP.envs.imports import _NEED_IMPORT_PADDLE +if _NEED_IMPORT_PADDLE: + import paddle + import paddle.distributed as dist @pytest.mark.paddle class TestDistUtilsTools: diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index 40bbe95e..a184bb11 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -13,10 +13,11 @@ from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset from tests.helpers.utils import magic_argv_env_context from fastNLP.core import rank_zero_rm - -import paddle -import paddle.distributed as dist -from paddle.io import DataLoader, BatchSampler +from fastNLP.envs.imports import _NEED_IMPORT_PADDLE +if _NEED_IMPORT_PADDLE: + import paddle + import paddle.distributed as dist + from paddle.io import DataLoader, BatchSampler def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"): paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension) diff --git a/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py b/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py index e27f2e0c..e339bbcc 100644 --- a/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py +++ b/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py @@ -5,8 +5,9 @@ from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initiali from fastNLP.envs import get_gpu_count from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from tests.helpers.utils import magic_argv_env_context - -import paddle +from fastNLP.envs.imports import _NEED_IMPORT_PADDLE +if _NEED_IMPORT_PADDLE: + import paddle @pytest.mark.paddle def test_incorrect_driver(): diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 326e102a..a00a41f5 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -8,10 +8,12 @@ from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandom from tests.helpers.datasets.torch_data import TorchNormalDataset from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from fastNLP.core import rank_zero_rm - -import paddle -from paddle.io import DataLoader, BatchSampler -import torch +from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH +if _NEED_IMPORT_PADDLE: + import paddle + from paddle.io import DataLoader, BatchSampler +if _NEED_IMPORT_TORCH: + import torch ############################################################################ # diff --git a/tests/core/drivers/paddle_driver/test_utils.py b/tests/core/drivers/paddle_driver/test_utils.py index 8db4de2d..4b683c1e 100644 --- a/tests/core/drivers/paddle_driver/test_utils.py +++ b/tests/core/drivers/paddle_driver/test_utils.py @@ -7,9 +7,10 @@ from fastNLP.core.drivers.paddle_driver.utils import ( replace_sampler, ) from fastNLP.core.samplers import RandomBatchSampler, RandomSampler - -import paddle -from paddle.io import DataLoader, BatchSampler +from fastNLP.envs.imports import _NEED_IMPORT_PADDLE +if _NEED_IMPORT_PADDLE: + import paddle + from paddle.io import DataLoader, BatchSampler from tests.helpers.datasets.paddle_data import PaddleNormalDataset diff --git a/tests/core/drivers/torch_driver/test_single_device.py b/tests/core/drivers/torch_driver/test_single_device.py index 29d1fe8e..8c761a95 100644 --- a/tests/core/drivers/torch_driver/test_single_device.py +++ b/tests/core/drivers/torch_driver/test_single_device.py @@ -8,10 +8,12 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDat from tests.helpers.datasets.paddle_data import PaddleNormalDataset from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from fastNLP.core import rank_zero_rm - -import torch -from torch.utils.data import DataLoader, BatchSampler -import paddle +from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + from torch.utils.data import DataLoader, BatchSampler +if _NEED_IMPORT_PADDLE: + import paddle def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): """ diff --git a/tests/core/metrics/test_accuracy_paddle.py b/tests/core/metrics/test_accuracy_paddle.py index 2d1e59fd..0dc65f1f 100644 --- a/tests/core/metrics/test_accuracy_paddle.py +++ b/tests/core/metrics/test_accuracy_paddle.py @@ -1,12 +1,14 @@ import os import pytest -import paddle -import paddle.distributed -import paddle.distributed.fleet.base.role_maker as role_maker -import paddle.distributed.fleet as fleet from fastNLP.core.metrics import Accuracy from fastNLP.core.drivers.paddle_driver.fleet_launcher import FleetLauncher +from fastNLP.envs.imports import _NEED_IMPORT_PADDLE +if _NEED_IMPORT_PADDLE: + import paddle + import paddle.distributed + import paddle.distributed.fleet.base.role_maker as role_maker + import paddle.distributed.fleet as fleet ############################################################################ diff --git a/tests/core/utils/test_paddle_utils.py b/tests/core/utils/test_paddle_utils.py index e3cb2329..ba9dcf79 100644 --- a/tests/core/utils/test_paddle_utils.py +++ b/tests/core/utils/test_paddle_utils.py @@ -1,7 +1,9 @@ import pytest -import paddle from fastNLP.core.utils.paddle_utils import paddle_to, paddle_move_data_to_device +from fastNLP.envs.imports import _NEED_IMPORT_PADDLE +if _NEED_IMPORT_PADDLE: + import paddle ############################################################################ diff --git a/tests/helpers/datasets/paddle_data.py b/tests/helpers/datasets/paddle_data.py index 17b2d310..0fa8ee83 100644 --- a/tests/helpers/datasets/paddle_data.py +++ b/tests/helpers/datasets/paddle_data.py @@ -1,7 +1,10 @@ -import paddle -from paddle.io import Dataset import numpy as np +from fastNLP.envs.imports import _NEED_IMPORT_PADDLE +if _NEED_IMPORT_PADDLE: + import paddle + from paddle.io import Dataset + class PaddleNormalDataset(Dataset): def __init__(self, num_of_data=1000): diff --git a/tests/helpers/models/paddle_model.py b/tests/helpers/models/paddle_model.py index efa8c0ce..7a897235 100644 --- a/tests/helpers/models/paddle_model.py +++ b/tests/helpers/models/paddle_model.py @@ -1,5 +1,7 @@ -import paddle -import paddle.nn as nn +from fastNLP.envs.imports import _NEED_IMPORT_PADDLE +if _NEED_IMPORT_PADDLE: + import paddle + import paddle.nn as nn class PaddleNormalModel_Classification_1(paddle.nn.Layer): """ From 21699749033dc65fe2f6fd400c54cd14d7e3567b Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 2 May 2022 07:20:54 +0000 Subject: [PATCH 3/6] =?UTF-8?q?=E5=AE=8C=E5=96=84=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E4=BE=8B=E7=9A=84import?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/dataloaders/paddle_dataloader/test_fdl.py | 2 ++ tests/helpers/datasets/paddle_data.py | 2 ++ tests/helpers/models/paddle_model.py | 7 +++++-- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/core/dataloaders/paddle_dataloader/test_fdl.py b/tests/core/dataloaders/paddle_dataloader/test_fdl.py index 484b0daa..abed1e83 100644 --- a/tests/core/dataloaders/paddle_dataloader/test_fdl.py +++ b/tests/core/dataloaders/paddle_dataloader/test_fdl.py @@ -7,6 +7,8 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE if _NEED_IMPORT_PADDLE: from paddle.io import Dataset, DataLoader import paddle +else: + from fastNLP.core.utils.dummy_class import DummyClass as Dataset class RandomDataset(Dataset): diff --git a/tests/helpers/datasets/paddle_data.py b/tests/helpers/datasets/paddle_data.py index 0fa8ee83..8a8d39b1 100644 --- a/tests/helpers/datasets/paddle_data.py +++ b/tests/helpers/datasets/paddle_data.py @@ -4,6 +4,8 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE if _NEED_IMPORT_PADDLE: import paddle from paddle.io import Dataset +else: + from fastNLP.core.utils.dummy_class import DummyClass as Dataset class PaddleNormalDataset(Dataset): diff --git a/tests/helpers/models/paddle_model.py b/tests/helpers/models/paddle_model.py index 7a897235..d2969b8e 100644 --- a/tests/helpers/models/paddle_model.py +++ b/tests/helpers/models/paddle_model.py @@ -2,8 +2,11 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE if _NEED_IMPORT_PADDLE: import paddle import paddle.nn as nn + from paddle.nn import Layer +else: + from fastNLP.core.utils.dummy_class import DummyClass as Layer -class PaddleNormalModel_Classification_1(paddle.nn.Layer): +class PaddleNormalModel_Classification_1(Layer): """ 基础的paddle分类模型 """ @@ -34,7 +37,7 @@ class PaddleNormalModel_Classification_1(paddle.nn.Layer): return {"pred": x, "target": y.reshape((-1,))} -class PaddleNormalModel_Classification_2(paddle.nn.Layer): +class PaddleNormalModel_Classification_2(Layer): """ 基础的paddle分类模型,只实现 forward 函数测试用户自己初始化了分布式的场景 """ From 296e7e9f2b315471a68e9c194109082480fc3db5 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Mon, 2 May 2022 17:42:38 +0800 Subject: [PATCH 4/6] bug fix for new_collator --- fastNLP/core/collators/new_collator.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/fastNLP/core/collators/new_collator.py b/fastNLP/core/collators/new_collator.py index 1d8636e3..cee713f2 100644 --- a/fastNLP/core/collators/new_collator.py +++ b/fastNLP/core/collators/new_collator.py @@ -16,7 +16,7 @@ SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None] CHECK_BACKEND = ['torch', 'jittor', 'paddle'] # backend 为 auto 时 检查是否是这些 backend -def _get_backend(): +def _get_backend() -> str: """ 当 Collator 的 backend 为 None 的时候如何,通过这个函数自动判定其 backend 。判断方法主要为以下两个: (1)尝试通过向上寻找当前 collator 的 callee 对象,根据 callee 对象寻找。然后使用 '/site-packages/{backend}' 来寻找是否是 @@ -57,7 +57,7 @@ def _get_backend(): else: break if len(catch_backend): - logger.debug(f"Find a file named:{catch_backend[1]} from stack contain backend:{catch_backend[0]}.") + logger.debug(f"Find a file named:{catch_backend[1]} from stack contains backend:{catch_backend[0]}.") return catch_backend[0] # 方式 (2) @@ -66,7 +66,7 @@ def _get_backend(): if catch_backend: break if len(catch_backend): - logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contain backend:{catch_backend[0]}.") + logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.") return catch_backend[0] return 'numpy' @@ -80,7 +80,7 @@ class Collator: 时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', auto, None]。 - 若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对本身就不能进行 pad 的数据没用影响,不能 pad + 若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对不能进行 pad 的数据没用影响,不能 pad 的数据返回一定是 list 。 """ self.unpack_batch_func = None @@ -144,15 +144,18 @@ class Collator: for key in unpack_batch.keys(): if key not in self.input_fields and key not in self.ignore_fields: self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} + elif key in self.input_fields and self.input_fields[key]['backend'] == 'auto': + self.input_fields[key]['backend'] = self.backend for field_name, setting in self.input_fields.items(): pad_fn = setting.get('pad_fn', None) if callable(pad_fn): padder = pad_fn else: + backend = self.backend if setting['backend'] == 'auto' else setting['backend'] batch_field = unpack_batch.get(field_name) padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'], - dtype=setting['dtype'], backend=setting['backend'], + dtype=setting['dtype'], backend=backend, field_name=field_name) self.padders[field_name] = padder if self.batch_data_type == 'l': From 4ba0ff2902f930bd1a961a4213b2015062c33016 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Mon, 2 May 2022 18:38:22 +0800 Subject: [PATCH 5/6] fix test --- fastNLP/core/callbacks/callback_events.py | 2 -- fastNLP/core/callbacks/callback_manager.py | 9 ++++++--- tests/core/controllers/test_trainer_paddle.py | 1 - tests/core/controllers/test_trainer_w_evaluator_torch.py | 7 ++----- tests/helpers/callbacks/helper_callbacks.py | 4 +--- tests/helpers/utils.py | 4 ++-- 6 files changed, 11 insertions(+), 16 deletions(-) diff --git a/fastNLP/core/callbacks/callback_events.py b/fastNLP/core/callbacks/callback_events.py index 3f3691e3..7252398c 100644 --- a/fastNLP/core/callbacks/callback_events.py +++ b/fastNLP/core/callbacks/callback_events.py @@ -4,8 +4,6 @@ from types import DynamicClassAttribute from functools import wraps -import fastNLP - __all__ = [ 'Events', 'EventsList', diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index f63c6088..2b8fff60 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -11,6 +11,7 @@ from .callback import Callback from fastNLP.core.log import logger from .progress_callback import ProgressCallback, choose_progress_callback from fastNLP.envs import rank_zero_call +from fastNLP.core.utils.utils import _get_fun_msg def _transfer(func): @@ -21,10 +22,12 @@ def _transfer(func): def wrapper(manager, *arg, **kwargs): manager.callback_counter[func.__name__] += 1 # 给实际被调用的 callback_fn 的计数加 1; - returns = [] for callback_fn in manager.callback_fns[func.__name__]: - returns.append(callback_fn(*arg, **kwargs)) - return returns + try: + callback_fn(*arg, **kwargs) + except BaseException as e: + logger.error(f"The following callback_fn raise exception:{_get_fun_msg(callback_fn)}.") + raise e return wrapper diff --git a/tests/core/controllers/test_trainer_paddle.py b/tests/core/controllers/test_trainer_paddle.py index 46feafa5..543c0c57 100644 --- a/tests/core/controllers/test_trainer_paddle.py +++ b/tests/core/controllers/test_trainer_paddle.py @@ -11,7 +11,6 @@ from paddle.io import DataLoader from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset -from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback from tests.helpers.utils import magic_argv_env_context @dataclass diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py index d8dd7d73..891626b5 100644 --- a/tests/core/controllers/test_trainer_w_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py @@ -100,17 +100,16 @@ def model_and_optimizers(request): # 测试一下普通的情况; @pytest.mark.torch @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) -@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]]) @pytest.mark.parametrize("evaluate_every", [-3, -1, 100]) @magic_argv_env_context def test_trainer_torch_with_evaluator( model_and_optimizers: TrainerParameters, driver, device, - callbacks, evaluate_every, n_epochs=10, ): + callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)] trainer = Trainer( model=model_and_optimizers.model, driver=driver, @@ -172,7 +171,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( if dist.is_initialized(): dist.destroy_process_group() - +@pytest.mark.torch @pytest.mark.parametrize("driver,device", [("torch", 1)]) # ("torch", [0, 1]),("torch", 1) @magic_argv_env_context def test_trainer_validate_every( @@ -184,9 +183,7 @@ def test_trainer_validate_every( def validate_every(trainer): if trainer.global_forward_batches % 10 == 0: - print(trainer) print("\nfastNLP test validate every.\n") - print(trainer.global_forward_batches) return True trainer = Trainer( diff --git a/tests/helpers/callbacks/helper_callbacks.py b/tests/helpers/callbacks/helper_callbacks.py index 4fd5b654..1e0d0e11 100644 --- a/tests/helpers/callbacks/helper_callbacks.py +++ b/tests/helpers/callbacks/helper_callbacks.py @@ -36,12 +36,10 @@ class RecordMetricCallback(Callback): self.larger_better = larger_better self.metric = None self.metric_threshold = metric_threshold - self.metric_begin_value = None + self.metric_begin_value = float('-inf') if larger_better else float('inf') def on_evaluate_end(self, trainer, results): self.metric = results[self.monitor] - if self.metric_begin_value is None: - self.metric_begin_value = self.metric def on_train_end(self, trainer): if self.larger_better: diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 7e02ca0d..463f144d 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -30,12 +30,12 @@ def recover_logger(fn): return wrapper -def magic_argv_env_context(fn=None, timeout=600): +def magic_argv_env_context(fn=None, timeout=300): """ 用来在测试时包裹每一个单独的测试函数,使得 ddp 测试正确; 会丢掉 pytest 中的 arg 参数。 - :param timeout: 表示一个测试如果经过多久还没有通过的话就主动将其 kill 掉,默认为 10 分钟,单位为秒; + :param timeout: 表示一个测试如果经过多久还没有通过的话就主动将其 kill 掉,默认为 5 分钟,单位为秒; :return: """ # 说明是通过 @magic_argv_env_context(timeout=600) 调用; From b1f1743487318ff133daf4072d42883a9ac30dd8 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 2 May 2022 10:41:02 +0000 Subject: [PATCH 6/6] =?UTF-8?q?test=5FTrainer=5Fpaddle=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=5FNEED=5FIMPORT=5FPADDLE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/controllers/test_trainer_paddle.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/core/controllers/test_trainer_paddle.py b/tests/core/controllers/test_trainer_paddle.py index 46feafa5..bc5a590c 100644 --- a/tests/core/controllers/test_trainer_paddle.py +++ b/tests/core/controllers/test_trainer_paddle.py @@ -4,9 +4,11 @@ from dataclasses import dataclass from fastNLP.core.controllers.trainer import Trainer from fastNLP.core.metrics.accuracy import Accuracy from fastNLP.core.callbacks.progress_callback import RichCallback +from fastNLP.envs.imports import _NEED_IMPORT_PADDLE -from paddle.optimizer import Adam -from paddle.io import DataLoader +if _NEED_IMPORT_PADDLE: + from paddle.optimizer import Adam + from paddle.io import DataLoader from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1