Browse Source

添加了对修改过的 Events 的测试

tags/v1.0.0alpha
YWMditto 2 years ago
parent
commit
8017d8c854
2 changed files with 6 additions and 41 deletions
  1. +6
    -5
      fastNLP/core/callbacks/callback_manager.py
  2. +0
    -36
      tests/core/callbacks/test_callback_event.py

+ 6
- 5
fastNLP/core/callbacks/callback_manager.py View File

@@ -127,11 +127,12 @@ class CallbackManager:
:param callback: 一个具体的 callback 实例;
"""
self.all_callbacks.append(callback)
for name, member in Event.__members__.items():
_fn = getattr(callback, member.value)
if inspect.getsource(_fn) != inspect.getsource(getattr(Callback, member.value)):
self.callback_fns[member.value].append(_fn)
self.extract_callback_filter_state(callback.callback_name, _fn)
for name, member in Event.__dict__.items():
if isinstance(member, staticmethod):
_fn = getattr(callback, name)
if inspect.getsource(_fn) != inspect.getsource(getattr(Callback, name)):
self.callback_fns[name].append(_fn)
self.extract_callback_filter_state(callback.callback_name, _fn)

def extract_callback_filter_state(self, callback_name, callback_fn):
r"""


+ 0
- 36
tests/core/callbacks/test_callback_event.py View File

@@ -6,42 +6,6 @@ from fastNLP.core.callbacks.callback_event import Event, Filter


class TestFilter:
def test_params_check(self):
# 顺利通过
_filter1 = Filter(every=10)
_filter2 = Filter(once=10)
_filter3 = Filter(filter_fn=lambda: None)

# 触发 ValueError
with pytest.raises(ValueError) as e:
_filter4 = Filter()
exec_msg = e.value.args[0]
assert exec_msg == "If you mean your decorated function should be called every time, you do not need this filter."

# 触发 ValueError
with pytest.raises(ValueError) as e:
_filter5 = Filter(every=10, once=10)
exec_msg = e.value.args[0]
assert exec_msg == "These three values should be only set one."

# 触发 TypeError
with pytest.raises(ValueError) as e:
_filter6 = Filter(every="heihei")
exec_msg = e.value.args[0]
assert exec_msg == "Argument every should be integer and greater than zero"

# 触发 TypeError
with pytest.raises(ValueError) as e:
_filter7 = Filter(once="heihei")
exec_msg = e.value.args[0]
assert exec_msg == "Argument once should be integer and positive"

# 触发 TypeError
with pytest.raises(TypeError) as e:
_filter7 = Filter(filter_fn="heihei")
exec_msg = e.value.args[0]
assert exec_msg == "Argument event_filter should be a callable"

def test_every_filter(self):
# every = 10
@Filter(every=10)


Loading…
Cancel
Save