diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 9dc02f3d..8b17f75a 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -128,29 +128,21 @@ class LossBase(object): self.param_map[arg] = arg # This param does not need mapping. self._evaluate_args = func_args self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} - - # need to wrap inputs in dict. + mapped_pred_dict = {} mapped_target_dict = {} - duplicated = [] - for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())): - not_duplicate_flag = 0 - if input_arg in self._reverse_param_map: - mapped_arg = self._reverse_param_map[input_arg] - not_duplicate_flag += 1 - else: - mapped_arg = input_arg + for input_arg, mapped_arg in self._reverse_param_map.items(): if input_arg in pred_dict: mapped_pred_dict[mapped_arg] = pred_dict[input_arg] - not_duplicate_flag += 1 if input_arg in target_dict: mapped_target_dict[mapped_arg] = target_dict[input_arg] - not_duplicate_flag += 1 - if not_duplicate_flag == 3: - duplicated.append(input_arg) # missing if not self._checked: + duplicated = [] + for input_arg, mapped_arg in self._reverse_param_map.items(): + if input_arg in pred_dict and input_arg in target_dict: + duplicated.append(input_arg) check_res = _check_arg_dict_list(self.get_loss, [mapped_pred_dict, mapped_target_dict]) # replace missing. missing = check_res.missing @@ -204,15 +196,12 @@ class LossFunc(LossBase): super(LossFunc, self).__init__() _check_function_or_method(func) + self.get_loss = func if key_map is not None: if not isinstance(key_map, dict): raise RuntimeError(f"Loss error: key_map except a {type({})} but got a {type(key_map)}") - self.param_map = key_map - if len(kwargs) > 0: - for key, val in kwargs.items(): - self.param_map.update({key: val}) + self._init_param_map(key_map, **kwargs) - self.get_loss = func class CrossEntropyLoss(LossBase): diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 19c33c86..37a94a08 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -223,25 +223,18 @@ class MetricBase(object): # need to wrap inputs in dict. mapped_pred_dict = {} mapped_target_dict = {} - duplicated = [] - for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())): - not_duplicate_flag = 0 - if input_arg in self._reverse_param_map: - mapped_arg = self._reverse_param_map[input_arg] - not_duplicate_flag += 1 - else: - mapped_arg = input_arg + for input_arg, mapped_arg in self._reverse_param_map.items(): if input_arg in pred_dict: mapped_pred_dict[mapped_arg] = pred_dict[input_arg] - not_duplicate_flag += 1 if input_arg in target_dict: mapped_target_dict[mapped_arg] = target_dict[input_arg] - not_duplicate_flag += 1 - if not_duplicate_flag == 3: - duplicated.append(input_arg) # missing if not self._checked: + duplicated = [] + for input_arg, mapped_arg in self._reverse_param_map.items(): + if input_arg in pred_dict and input_arg in target_dict: + duplicated.append(input_arg) check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict]) # only check missing. # replace missing. @@ -411,6 +404,37 @@ def _bmeso_tag_to_spans(tags, ignore_labels=None): ] +def _bioes_tag_to_spans(tags, ignore_labels=None): + """ + 给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'O']。 + 返回[('singer', (1, 4))] (左闭右开区间) + + :param tags: List[str], + :param ignore_labels: List[str], 在该list中的label将被忽略 + :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] + """ + ignore_labels = set(ignore_labels) if ignore_labels else set() + + spans = [] + prev_bmes_tag = None + for idx, tag in enumerate(tags): + tag = tag.lower() + bmes_tag, label = tag[:1], tag[2:] + if bmes_tag in ('b', 's'): + spans.append((label, [idx, idx])) + elif bmes_tag in ('i', 'e') and prev_bmes_tag in ('b', 'i') and label == spans[-1][0]: + spans[-1][1][1] = idx + elif bmes_tag == 'o': + pass + else: + spans.append((label, [idx, idx])) + prev_bmes_tag = bmes_tag + return [(span[0], (span[1][0], span[1][1] + 1)) + for span in spans + if span[0] not in ignore_labels + ] + + def _bio_tag_to_spans(tags, ignore_labels=None): """ 给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 @@ -471,7 +495,7 @@ class SpanFPreRecMetric(MetricBase): :param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 :param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 :param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_len'取数据。 - :param str encoding_type: 目前支持bio, bmes + :param str encoding_type: 目前支持bio, bmes, bmeso, bioes :param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这 个label :param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个 @@ -499,6 +523,8 @@ class SpanFPreRecMetric(MetricBase): self.tag_to_span_func = _bio_tag_to_spans elif self.encoding_type == 'bmeso': self.tag_to_span_func = _bmeso_tag_to_spans + elif self.encoding_type == 'bioes': + self.tag_to_span_func = _bioes_tag_to_spans else: raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.")