@@ -4,8 +4,6 @@ from types import DynamicClassAttribute
 from functools import wraps
-import fastNLP
 __all__ = [
     'Events',
     'EventsList',
@@ -11,6 +11,7 @@ from .callback import Callback
 from fastNLP.core.log import logger
 from .progress_callback import ProgressCallback, choose_progress_callback
 from fastNLP.envs import rank_zero_call
+from fastNLP.core.utils.utils import _get_fun_msg
 def _transfer(func):
@@ -21,10 +22,12 @@ def _transfer(func):
     def wrapper(manager, *arg, **kwargs):
         manager.callback_counter[func.__name__] += 1  # increment the count for the callback_fn that is actually invoked;
-        returns = []
         for callback_fn in manager.callback_fns[func.__name__]:
-            returns.append(callback_fn(*arg, **kwargs))
-        return returns
+            try:
+                callback_fn(*arg, **kwargs)
+            except BaseException as e:
+                logger.error(f"The following callback_fn raise exception:{_get_fun_msg(callback_fn)}.")
+                raise e
     return wrapper
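
For context, the hunk above changes `_transfer` so that callback dispatch no longer collects return values: each callback_fn is invoked inside try/except, the failing callback is named in the error log via `_get_fun_msg`, and the exception is re-raised so it still propagates. Below is a minimal, self-contained sketch of that dispatch pattern; the `CallbackManager` skeleton and the `_describe_fn` helper are illustrative stand-ins, not fastNLP's actual classes.

    import logging
    from collections import defaultdict

    logger = logging.getLogger(__name__)

    def _describe_fn(fn):
        # hypothetical stand-in for fastNLP's _get_fun_msg helper
        return f"{getattr(fn, '__qualname__', repr(fn))} in {getattr(fn, '__module__', '?')}"

    class CallbackManager:
        def __init__(self):
            # maps an event name to the list of callback functions registered for it
            self.callback_fns = defaultdict(list)

        def on_train_begin(self, *args, **kwargs):
            for callback_fn in self.callback_fns["on_train_begin"]:
                try:
                    callback_fn(*args, **kwargs)
                except BaseException:
                    # name the offending callback before propagating, so a failure
                    # inside one callback is easy to attribute in the logs
                    logger.error(f"callback_fn raised an exception: {_describe_fn(callback_fn)}")
                    raise
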
@@ -11,7 +11,6 @@ from paddle.io import DataLoader
 from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
 from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
-from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback
 from tests.helpers.utils import magic_argv_env_context
 @dataclass
@@ -100,17 +100,16 @@ def model_and_optimizers(request):
 # test the ordinary case;
 @pytest.mark.torch
 @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])])  # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
-@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]])
 @pytest.mark.parametrize("evaluate_every", [-3, -1, 100])
 @magic_argv_env_context
 def test_trainer_torch_with_evaluator(
         model_and_optimizers: TrainerParameters,
         driver,
         device,
-        callbacks,
         evaluate_every,
         n_epochs=10,
 ):
+    callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]
     trainer = Trainer(
         model=model_and_optimizers.model,
         driver=driver,
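
The hunk above replaces a one-element `@pytest.mark.parametrize("callbacks", ...)` axis with a plain local assignment. Since pytest builds parametrize values once at collection time, the old form likely shared a single stateful RecordMetricCallback instance across every generated driver/device/evaluate_every combination; constructing it inside the test body gives each run a fresh instance. A minimal sketch of the same refactor, with hypothetical test names:

    import pytest

    # before: "callbacks" contributed one fixed value to the parametrize matrix
    # @pytest.mark.parametrize("callbacks", [[object()]])
    # def test_train(driver, callbacks): ...

    @pytest.mark.parametrize("driver", ["torch"])
    def test_train(driver):
        # built per-test instead of shared through a one-element parametrize axis
        callbacks = [object()]
        assert driver == "torch" and len(callbacks) == 1
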
@@ -172,7 +171,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps(
     if dist.is_initialized():
         dist.destroy_process_group()
+@pytest.mark.torch
 @pytest.mark.parametrize("driver,device", [("torch", 1)])  # ("torch", [0, 1]),("torch", 1)
 @magic_argv_env_context
 def test_trainer_validate_every(
@@ -184,9 +183,7 @@ def test_trainer_validate_every(
     def validate_every(trainer):
         if trainer.global_forward_batches % 10 == 0:
-            print(trainer)
             print("\nfastNLP test validate every.\n")
-            print(trainer.global_forward_batches)
             return True
     trainer = Trainer(
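
In the test above, validate_every is passed to the Trainer as a callable hook: it receives the trainer and returns True whenever evaluation should run, here once every 10 global forward batches. A minimal sketch of the idea with a stand-in trainer; ToyTrainer below is illustrative, not fastNLP's Trainer:

    class ToyTrainer:
        # stand-in for the real Trainer, just enough to drive the hook
        def __init__(self, evaluate_every):
            self.global_forward_batches = 0
            self.evaluate_every = evaluate_every

        def run(self, n_batches=30):
            for _ in range(n_batches):
                self.global_forward_batches += 1  # one forward pass per batch
                if self.evaluate_every(self):
                    print(f"evaluating at batch {self.global_forward_batches}")

    def validate_every(trainer):
        # evaluate once every 10 forward batches
        return trainer.global_forward_batches % 10 == 0

    ToyTrainer(validate_every).run()  # evaluates at batches 10, 20, 30
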
@@ -36,12 +36,10 @@ class RecordMetricCallback(Callback):
         self.larger_better = larger_better
         self.metric = None
         self.metric_threshold = metric_threshold
-        self.metric_begin_value = None
+        self.metric_begin_value = float('-inf') if larger_better else float('inf')
     def on_evaluate_end(self, trainer, results):
         self.metric = results[self.monitor]
-        if self.metric_begin_value is None:
-            self.metric_begin_value = self.metric
     def on_train_end(self, trainer):
         if self.larger_better:
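
The callback hunk above swaps a None sentinel plus lazy first-evaluation initialization for a ±inf baseline: with larger_better, float('-inf') guarantees the first observed metric compares as an improvement, so on_evaluate_end no longer needs the None branch. A minimal sketch of the pattern; BestMetricTracker is an illustrative name, not fastNLP code:

    class BestMetricTracker:
        def __init__(self, larger_better=True):
            self.larger_better = larger_better
            # ±inf baseline: the first real metric always counts as an improvement,
            # so no `is None` branch is needed on the update path
            self.best = float('-inf') if larger_better else float('inf')

        def update(self, value):
            improved = value > self.best if self.larger_better else value < self.best
            if improved:
                self.best = value
            return improved

    tracker = BestMetricTracker(larger_better=True)
    assert tracker.update(0.3)       # first value always beats -inf
    assert not tracker.update(0.2)   # worse value is rejected
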
@@ -30,12 +30,12 @@ def recover_logger(fn):
     return wrapper
-def magic_argv_env_context(fn=None, timeout=600):
+def magic_argv_env_context(fn=None, timeout=300):
     """
     Wraps each individual test function at test time so that ddp tests run correctly;
     discards the arg parameters from pytest.
-    :param timeout: if a test has not passed after this long, it is actively killed; defaults to 10 minutes, in seconds;
+    :param timeout: if a test has not passed after this long, it is actively killed; defaults to 5 minutes, in seconds;
     :return:
     """
     # means it was invoked via @magic_argv_env_context(timeout=600);
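
magic_argv_env_context has the usual dual-use decorator shape: fn is the decorated function when the decorator is applied bare, and None when it is applied with arguments such as @magic_argv_env_context(timeout=600), which is what the trailing comment distinguishes. A minimal sketch of that shape, with the actual timeout and sys.argv handling elided; the wrapper body below is illustrative, not fastNLP's implementation:

    import functools

    def magic_argv_env_context(fn=None, timeout=300):
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                # a real implementation would patch sys.argv and enforce `timeout` here
                return func(*args, **kwargs)
            return wrapper
        if fn is None:
            # called with arguments: @magic_argv_env_context(timeout=600)
            return decorator
        # called bare: @magic_argv_env_context
        return decorator(fn)

    @magic_argv_env_context
    def test_bare():
        pass

    @magic_argv_env_context(timeout=60)
    def test_with_timeout():
        pass
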