From 8c22d0b1f61101fa4d32888367e80abfa1923318 Mon Sep 17 00:00:00 2001
From: YWMditto
Date: Tue, 12 Apr 2022 22:47:39 +0800
Subject: [PATCH] Modify the error message of Trainer.on
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/callbacks/callback.py           |  2 +-
 fastNLP/core/controllers/trainer.py          | 20 ++++++++-------
 fastNLP/core/utils/utils.py                  |  4 ---
 tests/core/callbacks/test_callback_events.py |  2 +-
 .../controllers/test_trainer_other_things.py | 25 +++++++++++++++++++
 5 files changed, 38 insertions(+), 15 deletions(-)
 create mode 100644 tests/core/controllers/test_trainer_other_things.py

diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py
index 4b553a1f..99e47dfe 100644
--- a/fastNLP/core/callbacks/callback.py
+++ b/fastNLP/core/callbacks/callback.py
@@ -71,7 +71,7 @@ class Callback:
         """
         pass
 
-    def on_train_batch_begin(self, trainer, batch, indices=None):
+    def on_train_batch_begin(self, trainer, batch, indices):
         r"""
         Triggered right before a specific batch starts during training;
 
diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py
index af589cbf..6d154770 100644
--- a/fastNLP/core/controllers/trainer.py
+++ b/fastNLP/core/controllers/trainer.py
@@ -130,9 +130,12 @@ class Trainer(TrainerEventTrigger):
             auto means that rich is used if the current terminal is detected to be interactive, otherwise raw is used.
         """
-
+        self.model = model
         self.marker = marker
-        self.driver_name = driver
+        if isinstance(driver, str):
+            self.driver_name = driver
+        else:
+            self.driver_name = driver.__class__.__name__
         self.device = device
         self.fp16 = fp16
         self.input_mapping = input_mapping
@@ -157,6 +160,8 @@ class Trainer(TrainerEventTrigger):
         elif accumulation_steps < 0:
             raise ValueError("Parameter `accumulation_steps` can only be bigger than 0.")
         self.accumulation_steps = accumulation_steps
+
+        # todo The rough idea: each driver declares what its own parameters are (matching the ones used at initialization), and the trainer/evaluator, when initialized, checks whether the parameters it holds are consistent with the driver's, warning the user about any values that differ from the driver's. This can probably be left for later.
         self.driver = choose_driver(
             model=model,
             driver=driver,
@@ -403,9 +408,10 @@ class Trainer(TrainerEventTrigger):
 
         def wrapper(fn: Callable) -> Callable:
             cls._custom_callbacks[marker].append((event, fn))
-            assert check_fn_not_empty_params(fn, len(get_fn_arg_names(getattr(Callback, event.value))) - 1), "Your " \
-                "callback fn's allowed parameters seem not to be equal with the origin callback fn in class " \
-                "`Callback` with the same callback time."
+            callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:]
+            assert check_fn_not_empty_params(fn, len(callback_fn_args)), \
+                f"The callback function at `{event.value.lower()}`'s parameters should be {callback_fn_args}, but your "\
+                f"function {fn.__name__} only has these parameters: {get_fn_arg_names(fn)}."
             return fn
 
         return wrapper
@@ -807,10 +813,6 @@ class Trainer(TrainerEventTrigger):
     def data_device(self):
         return self.driver.data_device
 
-    @property
-    def model(self):
-        # Return the model held by the driver; note that it may be wrapped by a distributed model such as `DistributedDataParallel`;
-        return self.driver.model
diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py
index 5c497606..46211581 100644
--- a/fastNLP/core/utils/utils.py
+++ b/fastNLP/core/utils/utils.py
@@ -44,15 +44,11 @@ __all__ = [
 ]
 
-
-
-
 def get_fn_arg_names(fn: Callable) -> List[str]:
     r"""
     Return the names of all parameters of a function;
 
     :param fn: the function to inspect;
-
    :return: a list whose elements are the string names of the queried function's parameters;
     """
     return list(inspect.signature(fn).parameters)
diff --git a/tests/core/callbacks/test_callback_events.py b/tests/core/callbacks/test_callback_events.py
index a71bb07f..8712b469 100644
--- a/tests/core/callbacks/test_callback_events.py
+++ b/tests/core/callbacks/test_callback_events.py
@@ -1,7 +1,7 @@
 import pytest
 from functools import reduce
 
-from fastNLP.core.callbacks.callback_events import Filter
+from fastNLP.core.callbacks.callback_events import Events, Filter
 
 
 class TestFilter:
diff --git a/tests/core/controllers/test_trainer_other_things.py b/tests/core/controllers/test_trainer_other_things.py
new file mode 100644
index 00000000..6327f4f8
--- /dev/null
+++ b/tests/core/controllers/test_trainer_other_things.py
@@ -0,0 +1,25 @@
+import pytest
+
+from fastNLP.core.controllers.trainer import Trainer
+from fastNLP.core.callbacks import Events
+from tests.helpers.utils import magic_argv_env_context
+
+
+@magic_argv_env_context
+def test_trainer_torch_without_evaluator():
+    @Trainer.on(Events.ON_TRAIN_EPOCH_BEGIN(every=10))
+    def fn1(trainer):
+        pass
+
+    @Trainer.on(Events.ON_TRAIN_BATCH_BEGIN(every=10))
+    def fn2(trainer, batch, indices):
+        pass
+
+    with pytest.raises(AssertionError):
+        @Trainer.on(Events.ON_TRAIN_BATCH_BEGIN(every=10))
+        def fn3(trainer, batch):
+            pass
+
+
+
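
Note for reviewers: the sketch below is a minimal, self-contained illustration of what the reworded assertion checks and the message it now produces. `assert_callback_signature` and its simple parameter-count comparison are hypothetical stand-ins for fastNLP's `check_fn_not_empty_params`, so this is not the library's actual implementation; only `get_fn_arg_names` mirrors the real helper.

    # A minimal sketch of the signature check behind the new assertion message.
    # `assert_callback_signature` and the parameter-count comparison are illustrative
    # stand-ins for `check_fn_not_empty_params`, not fastNLP's real API.
    import inspect
    from typing import Callable, List


    def get_fn_arg_names(fn: Callable) -> List[str]:
        # Same idea as fastNLP.core.utils.get_fn_arg_names: a function's parameter names.
        return list(inspect.signature(fn).parameters)


    class Callback:
        # Reference signature, matching the patched `on_train_batch_begin` above.
        def on_train_batch_begin(self, trainer, batch, indices):
            pass


    def assert_callback_signature(event_value: str, fn: Callable) -> None:
        # Drop `self` from the reference callback, then compare against the user's function,
        # producing an error message in the same spirit as the new f-string in Trainer.on.
        callback_fn_args = get_fn_arg_names(getattr(Callback, event_value))[1:]
        assert len(get_fn_arg_names(fn)) == len(callback_fn_args), \
            f"The callback function at `{event_value}`'s parameters should be {callback_fn_args}, but your " \
            f"function {fn.__name__} only has these parameters: {get_fn_arg_names(fn)}."


    def fn3(trainer, batch):  # missing `indices`, like fn3 in the new test
        pass


    assert_callback_signature("on_train_batch_begin", fn3)
    # AssertionError: The callback function at `on_train_batch_begin`'s parameters should be
    # ['trainer', 'batch', 'indices'], but your function fn3 only has these parameters:
    # ['trainer', 'batch'].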