|
|
@@ -118,6 +118,7 @@ class MetricBase(object): |
|
|
|
def __init__(self): |
|
|
|
self._param_map = {} # key is param in function, value is input param. |
|
|
|
self._checked = False |
|
|
|
self._metric_name = self.__class__.__name__ |
|
|
|
|
|
|
|
@property |
|
|
|
def param_map(self): |
|
|
@@ -135,6 +136,22 @@ class MetricBase(object): |
|
|
|
@abstractmethod |
|
|
|
def get_metric(self, reset=True): |
|
|
|
raise NotImplemented |
|
|
|
|
|
|
|
def set_metric_name(self, name:str): |
|
|
|
""" |
|
|
|
设置metric的名称,默认是Metric的class name. |
|
|
|
|
|
|
|
:param str name: |
|
|
|
:return: |
|
|
|
""" |
|
|
|
self._metric_name = name |
|
|
|
|
|
|
|
def get_metric_name(self): |
|
|
|
""" |
|
|
|
返回metric的名称 |
|
|
|
:return: |
|
|
|
""" |
|
|
|
return self._metric_name |
|
|
|
|
|
|
|
def _init_param_map(self, key_map=None, **kwargs): |
|
|
|
"""检查key_map和其他参数map,并将这些映射关系添加到self._param_map |
|
|
|