Browse Source

Merge branch 'dev0.5.0' of https://github.com/fastnlp/fastNLP into pr

tags/v0.4.10
yunfan 5 years ago
parent
commit
f7ebc1ca2c
2 changed files with 18 additions and 1 deletions
  1. +17
    -0
      fastNLP/core/metrics.py
  2. +1
    -1
      fastNLP/core/tester.py

+ 17
- 0
fastNLP/core/metrics.py View File

@@ -118,6 +118,7 @@ class MetricBase(object):
def __init__(self):
self._param_map = {} # key is param in function, value is input param.
self._checked = False
self._metric_name = self.__class__.__name__

@property
def param_map(self):
@@ -135,6 +136,22 @@ class MetricBase(object):
@abstractmethod
def get_metric(self, reset=True):
raise NotImplemented

def set_metric_name(self, name:str):
"""
设置metric的名称,默认是Metric的class name.

:param str name:
:return:
"""
self._metric_name = name

def get_metric_name(self):
"""
返回metric的名称
:return:
"""
return self._metric_name
def _init_param_map(self, key_map=None, **kwargs):
"""检查key_map和其他参数map,并将这些映射关系添加到self._param_map


+ 1
- 1
fastNLP/core/tester.py View File

@@ -178,7 +178,7 @@ class Tester(object):
if not isinstance(eval_result, dict):
raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be "
f"`dict`, got {type(eval_result)}")
metric_name = metric.__class__.__name__
metric_name = metric.get_metric_name()
eval_results[metric_name] = eval_result

end_time = time.time()


Loading…
Cancel
Save