Browse Source

修复metric和loss在映射时出现重复同名输入时会覆盖的bug

tags/v0.4.10
yh_cc 5 years ago
parent
commit
30b012ac20
2 changed files with 47 additions and 32 deletions
  1. +8
    -19
      fastNLP/core/losses.py
  2. +39
    -13
      fastNLP/core/metrics.py

+ 8
- 19
fastNLP/core/losses.py View File

@@ -128,29 +128,21 @@ class LossBase(object):
self.param_map[arg] = arg # This param does not need mapping.
self._evaluate_args = func_args
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}
# need to wrap inputs in dict.

mapped_pred_dict = {}
mapped_target_dict = {}
duplicated = []
for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())):
not_duplicate_flag = 0
if input_arg in self._reverse_param_map:
mapped_arg = self._reverse_param_map[input_arg]
not_duplicate_flag += 1
else:
mapped_arg = input_arg
for input_arg, mapped_arg in self._reverse_param_map.items():
if input_arg in pred_dict:
mapped_pred_dict[mapped_arg] = pred_dict[input_arg]
not_duplicate_flag += 1
if input_arg in target_dict:
mapped_target_dict[mapped_arg] = target_dict[input_arg]
not_duplicate_flag += 1
if not_duplicate_flag == 3:
duplicated.append(input_arg)
# missing
if not self._checked:
duplicated = []
for input_arg, mapped_arg in self._reverse_param_map.items():
if input_arg in pred_dict and input_arg in target_dict:
duplicated.append(input_arg)
check_res = _check_arg_dict_list(self.get_loss, [mapped_pred_dict, mapped_target_dict])
# replace missing.
missing = check_res.missing
@@ -204,15 +196,12 @@ class LossFunc(LossBase):
super(LossFunc, self).__init__()
_check_function_or_method(func)
self.get_loss = func
if key_map is not None:
if not isinstance(key_map, dict):
raise RuntimeError(f"Loss error: key_map except a {type({})} but got a {type(key_map)}")
self.param_map = key_map
if len(kwargs) > 0:
for key, val in kwargs.items():
self.param_map.update({key: val})
self._init_param_map(key_map, **kwargs)
self.get_loss = func


class CrossEntropyLoss(LossBase):


+ 39
- 13
fastNLP/core/metrics.py View File

@@ -223,25 +223,18 @@ class MetricBase(object):
# need to wrap inputs in dict.
mapped_pred_dict = {}
mapped_target_dict = {}
duplicated = []
for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())):
not_duplicate_flag = 0
if input_arg in self._reverse_param_map:
mapped_arg = self._reverse_param_map[input_arg]
not_duplicate_flag += 1
else:
mapped_arg = input_arg
for input_arg, mapped_arg in self._reverse_param_map.items():
if input_arg in pred_dict:
mapped_pred_dict[mapped_arg] = pred_dict[input_arg]
not_duplicate_flag += 1
if input_arg in target_dict:
mapped_target_dict[mapped_arg] = target_dict[input_arg]
not_duplicate_flag += 1
if not_duplicate_flag == 3:
duplicated.append(input_arg)
# missing
if not self._checked:
duplicated = []
for input_arg, mapped_arg in self._reverse_param_map.items():
if input_arg in pred_dict and input_arg in target_dict:
duplicated.append(input_arg)
check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict])
# only check missing.
# replace missing.
@@ -411,6 +404,37 @@ def _bmeso_tag_to_spans(tags, ignore_labels=None):
]


def _bioes_tag_to_spans(tags, ignore_labels=None):
"""
给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'O']。
返回[('singer', (1, 4))] (左闭右开区间)

:param tags: List[str],
:param ignore_labels: List[str], 在该list中的label将被忽略
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
"""
ignore_labels = set(ignore_labels) if ignore_labels else set()

spans = []
prev_bmes_tag = None
for idx, tag in enumerate(tags):
tag = tag.lower()
bmes_tag, label = tag[:1], tag[2:]
if bmes_tag in ('b', 's'):
spans.append((label, [idx, idx]))
elif bmes_tag in ('i', 'e') and prev_bmes_tag in ('b', 'i') and label == spans[-1][0]:
spans[-1][1][1] = idx
elif bmes_tag == 'o':
pass
else:
spans.append((label, [idx, idx]))
prev_bmes_tag = bmes_tag
return [(span[0], (span[1][0], span[1][1] + 1))
for span in spans
if span[0] not in ignore_labels
]


def _bio_tag_to_spans(tags, ignore_labels=None):
"""
给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。
@@ -471,7 +495,7 @@ class SpanFPreRecMetric(MetricBase):
:param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据
:param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_len'取数据。
:param str encoding_type: 目前支持bio, bmes
:param str encoding_type: 目前支持bio, bmes, bmeso, bioes
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这
个label
:param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个
@@ -499,6 +523,8 @@ class SpanFPreRecMetric(MetricBase):
self.tag_to_span_func = _bio_tag_to_spans
elif self.encoding_type == 'bmeso':
self.tag_to_span_func = _bmeso_tag_to_spans
elif self.encoding_type == 'bioes':
self.tag_to_span_func = _bioes_tag_to_spans
else:
raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.")


Loading…
Cancel
Save