diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index dfb20480..8b51e23c 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -296,6 +296,8 @@ class AccuracyMetric(MetricBase): def bmes_tag_to_spans(tags, ignore_labels=None): """ + 给定一个tags的list,比如['S', 'B-singer', 'M-singer', 'E-singer', 'S', 'S']。 + 返回[('', (0, 1)), ('singer', (1, 4)), ('', (4, 5)), ('', (5, 6))] (特别注意这是左闭右开区间) :param tags: List[str], :param ignore_labels: List[str], 在该list中的label将被忽略 @@ -315,13 +317,45 @@ def bmes_tag_to_spans(tags, ignore_labels=None): else: spans.append((label, [idx, idx])) prev_bmes_tag = bmes_tag - return [(span[0], (span[1][0], span[1][1])) + return [(span[0], (span[1][0], span[1][1]+1)) + for span in spans + if span[0] not in ignore_labels + ] + +def bmeso_tag_to_spans(tags, ignore_labels=None): + """ + 给定一个tags的list,比如['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。 + 返回[('singer', (1, 4))] (特别注意这是左闭右开区间) + + :param tags: List[str], + :param ignore_labels: List[str], 在该list中的label将被忽略 + :return: List[Tuple[str, List[int, int]]]. 
[(label,[start, end])] + """ + ignore_labels = set(ignore_labels) if ignore_labels else set() + + spans = [] + prev_bmes_tag = None + for idx, tag in enumerate(tags): + tag = tag.lower() + bmes_tag, label = tag[:1], tag[2:] + if bmes_tag in ('b', 's'): + spans.append((label, [idx, idx])) + elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label==spans[-1][0]: + spans[-1][1][1] = idx + elif bmes_tag == 'o': + pass + else: + spans.append((label, [idx, idx])) + prev_bmes_tag = bmes_tag + return [(span[0], (span[1][0], span[1][1]+1)) for span in spans if span[0] not in ignore_labels ] def bio_tag_to_spans(tags, ignore_labels=None): """ + 给定一个tags的list,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 + 返回[('singer', (1, 4))] (特别注意这是左闭右开区间) :param tags: List[str], :param ignore_labels: List[str], 在该list中的label将被忽略 @@ -343,7 +377,7 @@ def bio_tag_to_spans(tags, ignore_labels=None): else: spans.append((label, [idx, idx])) prev_bio_tag = bio_tag - return [(span[0], (span[1][0], span[1][1])) + return [(span[0], (span[1][0], span[1][1]+1)) for span in spans if span[0] not in ignore_labels ] @@ -390,8 +424,7 @@ class SpanFPreRecMetric(MetricBase): 则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 """ encoding_type = encoding_type.lower() - if encoding_type not in ('bio', 'bmes'): - raise ValueError("Only support 'bio' or 'bmes' type.") + if not isinstance(tag_vocab, Vocabulary): raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) if f_type not in ('micro', 'macro'): @@ -402,6 +435,11 @@ class SpanFPreRecMetric(MetricBase): self.tag_to_span_func = bmes_tag_to_spans elif self.encoding_type == 'bio': self.tag_to_span_func = bio_tag_to_spans + elif self.encoding_type == 'bmeso': + self.tag_to_span_func = bmeso_tag_to_spans + else: + raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") + self.ignore_labels = ignore_labels self.f_type = f_type self.beta = beta diff --git a/fastNLP/core/vocabulary.py 
b/fastNLP/core/vocabulary.py index 50a79d24..987a3527 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -44,10 +44,14 @@ class Vocabulary(object): :param int max_size: set the max number of words in Vocabulary. Default: None :param int min_freq: set the min occur frequency of words in Vocabulary. Default: None + :param padding: str, padding的字符,默认为<pad>。如果设置为None,则vocabulary中不考虑padding,为None的情况多在为label建立 + Vocabulary的情况。 + :param unknown: str, unknown的字符,默认为<unk>。如果设置为None,则vocabulary中不考虑unknown,为None的情况多在为label建立 + Vocabulary的情况。 """ - def __init__(self, max_size=None, min_freq=None, unknown='<unk>', padding='<pad>'): + def __init__(self, max_size=None, min_freq=None, padding='<pad>', unknown='<unk>'): self.max_size = max_size self.min_freq = min_freq self.word_count = Counter() @@ -97,9 +101,9 @@ class Vocabulary(object): """ self.word2idx = {} if self.padding is not None: - self.word2idx[self.padding] = 0 + self.word2idx[self.padding] = len(self.word2idx) if self.unknown is not None: - self.word2idx[self.unknown] = 1 + self.word2idx[self.unknown] = len(self.word2idx) max_size = min(self.max_size, len(self.word_count)) if self.max_size else None words = self.word_count.most_common(max_size) diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index 1fcdb7d9..09fce24f 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -877,6 +877,14 @@ class ConllPOSReader(object): class ConllxDataLoader(object): def load(self, path): + """ + + :param path: str,存储数据的路径 + :return: DataSet。内含field有'words', 'pos_tags', 'heads', 'labels'(parser的label) + 类似于拥有以下结构, 一行为一个instance(sample) + words pos_tags heads labels + ['some', ..] ['NN', ...] [2, 3...] ['nn', 'nn'...] 
+ """ datalist = [] with open(path, 'r', encoding='utf-8') as f: sample = [] diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py index d7db3bf9..e1b68e7a 100644 --- a/fastNLP/modules/decoder/CRF.py +++ b/fastNLP/modules/decoder/CRF.py @@ -25,7 +25,7 @@ def allowed_transitions(id2label, encoding_type='bio'): :param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。 - :param encoding_type: str, 支持"bio", "bmes"。 + :param encoding_type: str, 支持"bio", "bmes", "bmeso"。 :return: List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以 位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx). start_idx=len(id2label), end_idx=len(id2label)+1。 @@ -62,7 +62,7 @@ def allowed_transitions(id2label, encoding_type='bio'): def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): """ - :param encoding_type: str, 支持"BIO", "BMES"。 + :param encoding_type: str, 支持"BIO", "BMES", "BMESO"。 :param from_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag :param from_label: str, 比如"PER", "LOC"等label :param to_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag @@ -127,6 +127,18 @@ def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label) return to_tag in ['b', 's', 'end'] else: raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag)) + elif encoding_type == 'bmeso': + if from_tag == 'start': + return to_tag in ['b', 's', 'o'] + elif from_tag == 'b': + return to_tag in ['m', 'e'] and from_label==to_label + elif from_tag == 'm': + return to_tag in ['m', 'e'] and from_label==to_label + elif from_tag in ['e', 's', 'o']: + return to_tag in ['b', 's', 'end', 'o'] + else: + raise ValueError("Unexpect tag type {}. 
Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) + else: raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type))