Browse Source

增加metric测试用例,修改metric的 setattr方法

dev0.8.0
MorningForest 2 years ago
parent
commit
46b94002af
2 changed files with 41 additions and 0 deletions
  1. +3
    -0
      fastNLP/core/metrics/metric.py
  2. +38
    -0
      tests/core/metrics/test_metric.py

+ 3
- 0
fastNLP/core/metrics/metric.py View File

@@ -107,6 +107,9 @@ class Metric:
f"instead of {type(value)}.")
if isinstance(value, Element) and key not in self.elements:
raise RuntimeError("Please use register_element() function to add Element.")
attrs = self.__dict__
if key in attrs and isinstance(value, Element):
raise RuntimeError(f'`{key}` has been registered as an attribute, cannot be registered as an Element!')
object.__setattr__(self, key, value)

# 当调用 __getattribute__ 没有找到时才会触发这个, 保留这个的目的只是为了防止 ide 的 warning


+ 38
- 0
tests/core/metrics/test_metric.py View File

@@ -0,0 +1,38 @@
import pytest
from fastNLP import Metric


class DemoMetric(Metric):

def __init__(self):
super().__init__(backend='torch')
self.count = 0
self.register_element('count', 0)

def evaluate(self):
self.count += 1
print(self.count)


class DemoMetric1(Metric):

def __init__(self):
super().__init__(backend='torch')
self.register_element('count', 0)
self.count = 2

def evaluate(self):
self.count += 1
return self.count


class TestMetric:

def test_v1(self):
with pytest.raises(RuntimeError):
dmtr = DemoMetric()
dmtr.evaluate()

def test_v2(self):
dmtr = DemoMetric1()
assert 3 == dmtr.evaluate()

Loading…
Cancel
Save