|
@@ -22,7 +22,7 @@ from .utils import _check_arg_dict_list |
|
|
from .utils import _get_func_signature |
|
|
from .utils import _get_func_signature |
|
|
from .utils import seq_len_to_mask |
|
|
from .utils import seq_len_to_mask |
|
|
from .vocabulary import Vocabulary |
|
|
from .vocabulary import Vocabulary |
|
|
|
|
|
|
|
|
|
|
|
from abc import abstractmethod |
|
|
|
|
|
|
|
|
class MetricBase(object): |
|
|
class MetricBase(object): |
|
|
""" |
|
|
""" |
|
@@ -117,10 +117,12 @@ class MetricBase(object): |
|
|
def __init__(self): |
|
|
def __init__(self): |
|
|
self.param_map = {} # key is param in function, value is input param. |
|
|
self.param_map = {} # key is param in function, value is input param. |
|
|
self._checked = False |
|
|
self._checked = False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@abstractmethod |
|
|
def evaluate(self, *args, **kwargs): |
|
|
def evaluate(self, *args, **kwargs): |
|
|
raise NotImplementedError |
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@abstractmethod |
|
|
def get_metric(self, reset=True): |
|
|
def get_metric(self, reset=True): |
|
|
raise NotImplemented |
|
|
raise NotImplemented |
|
|
|
|
|
|
|
|