diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index ef6f8b69..007485b2 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -118,6 +118,7 @@ class MetricBase(object): def __init__(self): self._param_map = {} # key is param in function, value is input param. self._checked = False + self._metric_name = self.__class__.__name__ @property def param_map(self): @@ -135,6 +136,22 @@ class MetricBase(object): @abstractmethod def get_metric(self, reset=True): raise NotImplemented + + def set_metric_name(self, name:str): + """ + 设置metric的名称,默认是Metric的class name. + + :param str name: + :return: + """ + self._metric_name = name + + def get_metric_name(self): + """ + 返回metric的名称 + :return: + """ + return self._metric_name def _init_param_map(self, key_map=None, **kwargs): """检查key_map和其他参数map,并将这些映射关系添加到self._param_map diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index ab86fb62..e4d67261 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -178,7 +178,7 @@ class Tester(object): if not isinstance(eval_result, dict): raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be " f"`dict`, got {type(eval_result)}") - metric_name = metric.__class__.__name__ + metric_name = metric.get_metric_name() eval_results[metric_name] = eval_result end_time = time.time()