Browse Source

修改了 Trainer.on 的错误提示

tags/v1.0.0alpha
YWMditto 2 years ago
parent
commit
8c22d0b1f6
5 changed files with 38 additions and 15 deletions
  1. +1
    -1
      fastNLP/core/callbacks/callback.py
  2. +11
    -9
      fastNLP/core/controllers/trainer.py
  3. +0
    -4
      fastNLP/core/utils/utils.py
  4. +1
    -1
      tests/core/callbacks/test_callback_events.py
  5. +25
    -0
      tests/core/controllers/test_trainer_other_things.py

+ 1
- 1
fastNLP/core/callbacks/callback.py View File

@@ -71,7 +71,7 @@ class Callback:
"""
pass

def on_train_batch_begin(self, trainer, batch, indices=None):
def on_train_batch_begin(self, trainer, batch, indices):
r"""
在训练过程中开始具体的一个 batch 前会被触发;



+ 11
- 9
fastNLP/core/controllers/trainer.py View File

@@ -130,9 +130,12 @@ class Trainer(TrainerEventTrigger):
auto 表示如果检测到当前 terminal 为交互型 则使用 rich,否则使用 raw。

"""
self.model = model
self.marker = marker
self.driver_name = driver
if isinstance(driver, str):
self.driver_name = driver
else:
self.driver_name = driver.__class__.__name__
self.device = device
self.fp16 = fp16
self.input_mapping = input_mapping
@@ -157,6 +160,8 @@ class Trainer(TrainerEventTrigger):
elif accumulation_steps < 0:
raise ValueError("Parameter `accumulation_steps` can only be bigger than 0.")
self.accumulation_steps = accumulation_steps

# todo 思路大概是,每个driver提供一下自己的参数是啥(需要对应回初始化的那个),然后trainer/evalutor在初始化的时候,就检测一下自己手上的参数和driver的是不是一致的,不一致的地方需要warn用户说这些值driver不太一样。感觉可以留到后面做吧
self.driver = choose_driver(
model=model,
driver=driver,
@@ -403,9 +408,10 @@ class Trainer(TrainerEventTrigger):

def wrapper(fn: Callable) -> Callable:
cls._custom_callbacks[marker].append((event, fn))
assert check_fn_not_empty_params(fn, len(get_fn_arg_names(getattr(Callback, event.value))) - 1), "Your " \
"callback fn's allowed parameters seem not to be equal with the origin callback fn in class " \
"`Callback` with the same callback time."
callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:]
assert check_fn_not_empty_params(fn, len(callback_fn_args)), \
f"The callback function at `{event.value.lower()}`'s parameters should be {callback_fn_args}, but your "\
f"function {fn.__name__} only has these parameters: {get_fn_arg_names(fn)}."
return fn

return wrapper
@@ -807,10 +813,6 @@ class Trainer(TrainerEventTrigger):
def data_device(self):
return self.driver.data_device

@property
def model(self):
# 返回 driver 中的 model,注意该 model 可能被分布式的模型包裹,例如 `DistributedDataParallel`;
return self.driver.model





+ 0
- 4
fastNLP/core/utils/utils.py View File

@@ -44,15 +44,11 @@ __all__ = [
]





def get_fn_arg_names(fn: Callable) -> List[str]:
r"""
返回一个函数的所有参数的名字;

:param fn: 需要查询的函数;

:return: 一个列表,其中的元素则是查询函数的参数的字符串名字;
"""
return list(inspect.signature(fn).parameters)


+ 1
- 1
tests/core/callbacks/test_callback_events.py View File

@@ -1,7 +1,7 @@
import pytest
from functools import reduce

from fastNLP.core.callbacks.callback_events import Filter
from fastNLP.core.callbacks.callback_events import Events, Filter


class TestFilter:


+ 25
- 0
tests/core/controllers/test_trainer_other_things.py View File

@@ -0,0 +1,25 @@
import pytest

from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.callbacks import Events
from tests.helpers.utils import magic_argv_env_context


@magic_argv_env_context
def test_trainer_torch_without_evaluator():
@Trainer.on(Events.ON_TRAIN_EPOCH_BEGIN(every=10))
def fn1(trainer):
pass

@Trainer.on(Events.ON_TRAIN_BATCH_BEGIN(every=10))
def fn2(trainer, batch, indices):
pass

with pytest.raises(AssertionError):
@Trainer.on(Events.ON_TRAIN_BATCH_BEGIN(every=10))
def fn3(trainer, batch):
pass





Loading…
Cancel
Save