|
@@ -223,25 +223,18 @@ class MetricBase(object): |
|
|
# need to wrap inputs in dict. |
|
|
# need to wrap inputs in dict. |
|
|
mapped_pred_dict = {} |
|
|
mapped_pred_dict = {} |
|
|
mapped_target_dict = {} |
|
|
mapped_target_dict = {} |
|
|
duplicated = [] |
|
|
|
|
|
for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())): |
|
|
|
|
|
not_duplicate_flag = 0 |
|
|
|
|
|
if input_arg in self._reverse_param_map: |
|
|
|
|
|
mapped_arg = self._reverse_param_map[input_arg] |
|
|
|
|
|
not_duplicate_flag += 1 |
|
|
|
|
|
else: |
|
|
|
|
|
mapped_arg = input_arg |
|
|
|
|
|
|
|
|
for input_arg, mapped_arg in self._reverse_param_map.items(): |
|
|
if input_arg in pred_dict: |
|
|
if input_arg in pred_dict: |
|
|
mapped_pred_dict[mapped_arg] = pred_dict[input_arg] |
|
|
mapped_pred_dict[mapped_arg] = pred_dict[input_arg] |
|
|
not_duplicate_flag += 1 |
|
|
|
|
|
if input_arg in target_dict: |
|
|
if input_arg in target_dict: |
|
|
mapped_target_dict[mapped_arg] = target_dict[input_arg] |
|
|
mapped_target_dict[mapped_arg] = target_dict[input_arg] |
|
|
not_duplicate_flag += 1 |
|
|
|
|
|
if not_duplicate_flag == 3: |
|
|
|
|
|
duplicated.append(input_arg) |
|
|
|
|
|
|
|
|
|
|
|
# missing |
|
|
# missing |
|
|
if not self._checked: |
|
|
if not self._checked: |
|
|
|
|
|
duplicated = [] |
|
|
|
|
|
for input_arg, mapped_arg in self._reverse_param_map.items(): |
|
|
|
|
|
if input_arg in pred_dict and input_arg in target_dict: |
|
|
|
|
|
duplicated.append(input_arg) |
|
|
check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict]) |
|
|
check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict]) |
|
|
# only check missing. |
|
|
# only check missing. |
|
|
# replace missing. |
|
|
# replace missing. |
|
@@ -411,6 +404,37 @@ def _bmeso_tag_to_spans(tags, ignore_labels=None): |
|
|
] |
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _bioes_tag_to_spans(tags, ignore_labels=None): |
|
|
|
|
|
""" |
|
|
|
|
|
给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'O']。 |
|
|
|
|
|
返回[('singer', (1, 4))] (左闭右开区间) |
|
|
|
|
|
|
|
|
|
|
|
:param tags: List[str], |
|
|
|
|
|
:param ignore_labels: List[str], 在该list中的label将被忽略 |
|
|
|
|
|
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] |
|
|
|
|
|
""" |
|
|
|
|
|
ignore_labels = set(ignore_labels) if ignore_labels else set() |
|
|
|
|
|
|
|
|
|
|
|
spans = [] |
|
|
|
|
|
prev_bmes_tag = None |
|
|
|
|
|
for idx, tag in enumerate(tags): |
|
|
|
|
|
tag = tag.lower() |
|
|
|
|
|
bmes_tag, label = tag[:1], tag[2:] |
|
|
|
|
|
if bmes_tag in ('b', 's'): |
|
|
|
|
|
spans.append((label, [idx, idx])) |
|
|
|
|
|
elif bmes_tag in ('i', 'e') and prev_bmes_tag in ('b', 'i') and label == spans[-1][0]: |
|
|
|
|
|
spans[-1][1][1] = idx |
|
|
|
|
|
elif bmes_tag == 'o': |
|
|
|
|
|
pass |
|
|
|
|
|
else: |
|
|
|
|
|
spans.append((label, [idx, idx])) |
|
|
|
|
|
prev_bmes_tag = bmes_tag |
|
|
|
|
|
return [(span[0], (span[1][0], span[1][1] + 1)) |
|
|
|
|
|
for span in spans |
|
|
|
|
|
if span[0] not in ignore_labels |
|
|
|
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _bio_tag_to_spans(tags, ignore_labels=None): |
|
|
def _bio_tag_to_spans(tags, ignore_labels=None): |
|
|
""" |
|
|
""" |
|
|
给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 |
|
|
给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 |
|
@@ -471,7 +495,7 @@ class SpanFPreRecMetric(MetricBase): |
|
|
:param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 |
|
|
:param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 |
|
|
:param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 |
|
|
:param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 |
|
|
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_len'取数据。 |
|
|
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_len'取数据。 |
|
|
:param str encoding_type: 目前支持bio, bmes |
|
|
|
|
|
|
|
|
:param str encoding_type: 目前支持bio, bmes, bmeso, bioes |
|
|
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这 |
|
|
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这 |
|
|
个label |
|
|
个label |
|
|
:param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个 |
|
|
:param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个 |
|
@@ -499,6 +523,8 @@ class SpanFPreRecMetric(MetricBase): |
|
|
self.tag_to_span_func = _bio_tag_to_spans |
|
|
self.tag_to_span_func = _bio_tag_to_spans |
|
|
elif self.encoding_type == 'bmeso': |
|
|
elif self.encoding_type == 'bmeso': |
|
|
self.tag_to_span_func = _bmeso_tag_to_spans |
|
|
self.tag_to_span_func = _bmeso_tag_to_spans |
|
|
|
|
|
elif self.encoding_type == 'bioes': |
|
|
|
|
|
self.tag_to_span_func = _bioes_tag_to_spans |
|
|
else: |
|
|
else: |
|
|
raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") |
|
|
raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") |
|
|
|
|
|
|
|
|