|
|
@@ -162,7 +162,7 @@ class TestCallbackEvents: |
|
|
|
def test_every(self): |
|
|
|
|
|
|
|
# 这里是什么样的事件是不影响的,因为我们是与 Trainer 拆分开了进行测试; |
|
|
|
event_state = Events.on_train_begin() # 什么都不输入是应当默认 every=1; |
|
|
|
event_state = Event.on_train_begin() # 什么都不输入是应当默认 every=1; |
|
|
|
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) |
|
|
|
def _fn(data): |
|
|
|
return data |
|
|
@@ -174,7 +174,7 @@ class TestCallbackEvents: |
|
|
|
_res.append(cu_res) |
|
|
|
assert _res == list(range(100)) |
|
|
|
|
|
|
|
event_state = Events.on_train_begin(every=10) |
|
|
|
event_state = Event.on_train_begin(every=10) |
|
|
|
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) |
|
|
|
def _fn(data): |
|
|
|
return data |
|
|
@@ -187,7 +187,7 @@ class TestCallbackEvents: |
|
|
|
assert _res == [w - 1 for w in range(10, 101, 10)] |
|
|
|
|
|
|
|
def test_once(self): |
|
|
|
event_state = Events.on_train_begin(once=10) |
|
|
|
event_state = Event.on_train_begin(once=10) |
|
|
|
|
|
|
|
@Filter(once=event_state.once) |
|
|
|
def _fn(data): |
|
|
@@ -220,7 +220,7 @@ def test_callback_events_torch(): |
|
|
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
event_state = Events.on_train_begin(filter_fn=filter_fn) |
|
|
|
event_state = Event.on_train_begin(filter_fn=filter_fn) |
|
|
|
|
|
|
|
@Filter(filter_fn=event_state.filter_fn) |
|
|
|
def _fn(trainer, data): |
|
|
|