diff --git a/fastNLP/core/metrics/metric.py b/fastNLP/core/metrics/metric.py index 49714200..c3d23bef 100644 --- a/fastNLP/core/metrics/metric.py +++ b/fastNLP/core/metrics/metric.py @@ -107,6 +107,9 @@ class Metric: f"instead of {type(value)}.") if isinstance(value, Element) and key not in self.elements: raise RuntimeError("Please use register_element() function to add Element.") + attrs = self.__dict__ + if key in attrs and isinstance(value, Element): + raise RuntimeError(f'`{key}` has been registered as an attribute, cannot be registered as an Element!') object.__setattr__(self, key, value) # 当调用 __getattribute__ 没有找到时才会触发这个, 保留这个的目的只是为了防止 ide 的 warning diff --git a/tests/core/metrics/test_metric.py b/tests/core/metrics/test_metric.py new file mode 100644 index 00000000..303990d6 --- /dev/null +++ b/tests/core/metrics/test_metric.py @@ -0,0 +1,38 @@ +import pytest +from fastNLP import Metric + + +class DemoMetric(Metric): + + def __init__(self): + super().__init__(backend='torch') + self.count = 0 + self.register_element('count', 0) + + def evaluate(self): + self.count += 1 + print(self.count) + + +class DemoMetric1(Metric): + + def __init__(self): + super().__init__(backend='torch') + self.register_element('count', 0) + self.count = 2 + + def evaluate(self): + self.count += 1 + return self.count + + +class TestMetric: + + def test_v1(self): + with pytest.raises(RuntimeError): + dmtr = DemoMetric() + dmtr.evaluate() + + def test_v2(self): + dmtr = DemoMetric1() + assert 3 == dmtr.evaluate()