Browse Source

修改测试例中的Events为Event

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
c8e8ff4a8c
2 changed files with 8 additions and 8 deletions
  1. +4
    -4
      tests/core/callbacks/test_callback_event.py
  2. +4
    -4
      tests/core/controllers/test_trainer_other_things.py

+ 4
- 4
tests/core/callbacks/test_callback_event.py View File

@@ -162,7 +162,7 @@ class TestCallbackEvents:
def test_every(self):

# 这里是什么样的事件是不影响的,因为我们是与 Trainer 拆分开了进行测试;
event_state = Events.on_train_begin() # 什么都不输入是应当默认 every=1;
event_state = Event.on_train_begin() # 什么都不输入是应当默认 every=1;
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn)
def _fn(data):
return data
@@ -174,7 +174,7 @@ class TestCallbackEvents:
_res.append(cu_res)
assert _res == list(range(100))

event_state = Events.on_train_begin(every=10)
event_state = Event.on_train_begin(every=10)
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn)
def _fn(data):
return data
@@ -187,7 +187,7 @@ class TestCallbackEvents:
assert _res == [w - 1 for w in range(10, 101, 10)]

def test_once(self):
event_state = Events.on_train_begin(once=10)
event_state = Event.on_train_begin(once=10)

@Filter(once=event_state.once)
def _fn(data):
@@ -220,7 +220,7 @@ def test_callback_events_torch():
return True
return False

event_state = Events.on_train_begin(filter_fn=filter_fn)
event_state = Event.on_train_begin(filter_fn=filter_fn)

@Filter(filter_fn=event_state.filter_fn)
def _fn(trainer, data):


+ 4
- 4
tests/core/controllers/test_trainer_other_things.py View File

@@ -1,22 +1,22 @@
import pytest

from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.callbacks import Events
from fastNLP.core.callbacks import Event
from tests.helpers.utils import magic_argv_env_context


@magic_argv_env_context
def test_trainer_torch_without_evaluator():
@Trainer.on(Events.on_train_epoch_begin(every=10), marker="test_trainer_other_things")
@Trainer.on(Event.on_train_epoch_begin(every=10), marker="test_trainer_other_things")
def fn1(trainer):
pass

@Trainer.on(Events.on_train_batch_begin(every=10), marker="test_trainer_other_things")
@Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things")
def fn2(trainer, batch, indices):
pass

with pytest.raises(BaseException):
@Trainer.on(Events.on_train_batch_begin(every=10), marker="test_trainer_other_things")
@Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things")
def fn3(trainer, batch):
pass



Loading…
Cancel
Save