From 46b94002af02e89617187cdcd8fb0bd116bc85f4 Mon Sep 17 00:00:00 2001 From: MorningForest <2297662686@qq.com> Date: Fri, 16 Sep 2022 13:34:07 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0metric=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E7=94=A8=E4=BE=8B=EF=BC=8C=E4=BF=AE=E6=94=B9metric=E7=9A=84=20?= =?UTF-8?q?setattr=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/metrics/metric.py | 3 +++ tests/core/metrics/test_metric.py | 38 +++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) create mode 100644 tests/core/metrics/test_metric.py diff --git a/fastNLP/core/metrics/metric.py b/fastNLP/core/metrics/metric.py index 49714200..c3d23bef 100644 --- a/fastNLP/core/metrics/metric.py +++ b/fastNLP/core/metrics/metric.py @@ -107,6 +107,9 @@ class Metric: f"instead of {type(value)}.") if isinstance(value, Element) and key not in self.elements: raise RuntimeError("Please use register_element() function to add Element.") + attrs = self.__dict__ + if key in attrs and isinstance(value, Element): + raise RuntimeError(f'`{key}` has been registered as an attribute, cannot be registered as an Element!') object.__setattr__(self, key, value) # 当调用 __getattribute__ 没有找到时才会触发这个, 保留这个的目的只是为了防止 ide 的 warning diff --git a/tests/core/metrics/test_metric.py b/tests/core/metrics/test_metric.py new file mode 100644 index 00000000..303990d6 --- /dev/null +++ b/tests/core/metrics/test_metric.py @@ -0,0 +1,38 @@ +import pytest +from fastNLP import Metric + + +class DemoMetric(Metric): + + def __init__(self): + super().__init__(backend='torch') + self.count = 0 + self.register_element('count', 0) + + def evaluate(self): + self.count += 1 + print(self.count) + + +class DemoMetric1(Metric): + + def __init__(self): + super().__init__(backend='torch') + self.register_element('count', 0) + self.count = 2 + + def evaluate(self): + self.count += 1 + return self.count + + +class TestMetric: + + def test_v1(self): + with pytest.raises(RuntimeError): + dmtr = DemoMetric() + dmtr.evaluate() + + def test_v2(self): + dmtr = DemoMetric1() + assert 3 == dmtr.evaluate()